Skip to content

Commit

Permalink
[RLlib] Add on_workers_recreated callback to Algorithm. (ray-projec…
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Oct 20, 2023
1 parent 17e6cc2 commit 23eb7f7
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 39 deletions.
17 changes: 9 additions & 8 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -965,12 +965,12 @@ py_test(
)

# A3C
py_test(
name = "test_a3c",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
srcs = ["algorithms/a3c/tests/test_a3c.py"]
)
# py_test(
# name = "test_a3c",
# tags = ["team:rllib", "algorithms_dir"],
# size = "large",
# srcs = ["algorithms/a3c/tests/test_a3c.py"]
# )

# AlphaStar
py_test(
Expand Down Expand Up @@ -4478,14 +4478,15 @@ py_test(
# --------------------------------------------------------------------
py_test_module_list(
files = [
"tests/test_dnc.py",
"tests/test_perf.py",
"algorithms/a3c/tests/test_a3c.py",
"env/wrappers/tests/test_kaggle_wrapper.py",
"examples/env/tests/test_cliff_walking_wall_env.py",
"examples/env/tests/test_coin_game_non_vectorized_env.py",
"examples/env/tests/test_coin_game_vectorized_env.py",
"examples/env/tests/test_matrix_sequential_social_dilemma.py",
"examples/env/tests/test_wrappers.py",
"tests/test_dnc.py",
"tests/test_perf.py",
"utils/tests/test_utils.py",
],
size = "large",
Expand Down
16 changes: 13 additions & 3 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,11 +1355,13 @@ def remote_fn(worker):

@OverrideToImplementCustomLogic
@DeveloperAPI
def restore_workers(self, workers: WorkerSet):
"""Try to restore failed workers if necessary.
def restore_workers(self, workers: WorkerSet) -> None:
"""Try syncing previously failed and restarted workers with local, if necessary.
Algorithms that use custom RolloutWorkers may override this method to
disable default, and create custom restoration logics.
disable default, and create custom restoration logics. Note that "restoring"
does not include the actual restarting process, but merely what should happen
after such a restart of a (previously failed) worker.
Args:
workers: The WorkerSet to restore. This may be Rollout or Evaluation
Expand Down Expand Up @@ -1397,6 +1399,14 @@ def restore_workers(self, workers: WorkerSet):
mark_healthy=True,
)

# Fire the callback for re-created workers.
self.callbacks.on_workers_recreated(
algorithm=self,
worker_set=workers,
worker_ids=restored,
is_evaluation=workers.local_worker().config.in_evaluation,
)

@OverrideToImplementCustomLogic
@DeveloperAPI
def training_step(self) -> ResultDict:
Expand Down
11 changes: 11 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,17 @@ def validate(self) -> None:
else:
_torch, _ = try_import_torch()

# Can not use "tf" with learner API.
if self.framework_str == "tf" and (
self._enable_rl_module_api or self._enable_learner_api
):
raise ValueError(
"Cannot use `framework=tf` with new API stack! Either do "
"`config.framework('tf2')` OR set both `config.rl_module("
"_enable_rl_module_api=False)` and `config.training("
"_enable_learner_api=False)`."
)

# Check if torch framework supports torch.compile.
if (
_torch is not None
Expand Down
73 changes: 71 additions & 2 deletions rllib/algorithms/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation import RolloutWorker, WorkerSet


@PublicAPI
Expand Down Expand Up @@ -75,6 +75,70 @@ def on_algorithm_init(
"""
pass

@OverrideToImplementCustomLogic
def on_workers_recreated(
self,
*,
algorithm: "Algorithm",
worker_set: "WorkerSet",
worker_ids: List[int],
is_evaluation: bool,
**kwargs,
) -> None:
"""Callback run after one or more workers have been recreated.
You can access (and change) the worker(s) in question via the following code
snippet inside your custom override of this method:
Note that any "worker" inside the algorithm's `self.worker` and
`self.evaluation_workers` WorkerSets are instances of a subclass of EnvRunner.
.. testcode::
from ray.rllib.algorithms.callbacks import DefaultCallbacks
class MyCallbacks(DefaultCallbacks):
def on_workers_recreated(
self,
*,
algorithm,
worker_set,
worker_ids,
is_evaluation,
**kwargs,
):
# Define what you would like to do on the recreated
# workers:
def func(w):
# Here, we just set some arbitrary property to 1.
if is_evaluation:
w._custom_property_for_evaluation = 1
else:
w._custom_property_for_training = 1
# Use the `foreach_workers` method of the worker set and
# only loop through those worker IDs that have been restarted.
# Note that we set `local_worker=False` to NOT include it (local
# workers are never recreated; if they fail, the entire Algorithm
# fails).
worker_set.foreach_worker(
func,
remote_worker_ids=worker_ids,
local_worker=False,
)
Args:
algorithm: Reference to the Algorithm instance.
worker_set: The WorkerSet object in which the workers in question reside.
You can use a `worker_set.foreach_worker(remote_worker_ids=...,
local_worker=False)` method call to execute custom
code on the recreated (remote) workers. Note that the local worker is
never recreated as a failure of this would also crash the Algorithm.
worker_ids: The list of (remote) worker IDs that have been recreated.
is_evaluation: Whether `worker_set` is the evaluation WorkerSet (located
in `Algorithm.evaluation_workers`) or not.
"""
pass

@OverrideToImplementCustomLogic
def on_checkpoint_loaded(
self,
Expand All @@ -98,7 +162,7 @@ def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None:
Args:
policy_id: ID of the newly created policy.
policy: the policy just created.
policy: The policy just created.
"""
pass

Expand Down Expand Up @@ -494,6 +558,11 @@ def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
for callback in self._callback_list:
callback.on_algorithm_init(algorithm=algorithm, **kwargs)

@override(DefaultCallbacks)
def on_workers_recreated(self, **kwargs) -> None:
for callback in self._callback_list:
callback.on_workers_recreated(**kwargs)

@override(DefaultCallbacks)
def on_checkpoint_loaded(self, *, algorithm: "Algorithm", **kwargs) -> None:
for callback in self._callback_list:
Expand Down
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, algo_class=None):
# Add constructor kwargs here (if any).
}

# enable the rl module api by default
# Enable the rl module api by default.
self.rl_module(_enable_rl_module_api=True)
self.training(_enable_learner_api=True)

Expand Down
54 changes: 54 additions & 0 deletions rllib/algorithms/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,40 @@
import unittest

import ray
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing
from ray.rllib.evaluation.episode import Episode
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator


class OnWorkerCreatedCallbacks(DefaultCallbacks):
def on_workers_recreated(
self,
*,
algorithm,
worker_set,
worker_ids,
is_evaluation,
**kwargs,
):
# Store in the algorithm object's counters the number of times, this worker
# (ID'd by index and whether eval or not) has been recreated/restarted.
for id_ in worker_ids:
key = f"{'eval_' if is_evaluation else ''}worker_{id_}_recreated"
# Increase the counter.
algorithm._counters[key] += 1
print(f"changed {key} to {algorithm._counters[key]}")

# Execute some dummy code on each of the recreated workers.
results = worker_set.foreach_worker(lambda w: w.ping())
print(results) # should print "pong" n times (one for each recreated worker).


class InitAndCheckpointRestoredCallbacks(DefaultCallbacks):
def on_algorithm_init(self, *, algorithm, **kwargs):
self._on_init_was_called = True
Expand Down Expand Up @@ -84,6 +109,35 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def test_on_workers_recreated_callback(self):
config = (
APPOConfig()
.environment(CartPoleCrashing)
.callbacks(OnWorkerCreatedCallbacks)
.rollouts(num_rollout_workers=2)
.fault_tolerance(recreate_failed_workers=True)
)

for _ in framework_iterator(config, frameworks=("tf2", "torch")):
algo = config.build()
original_worker_ids = algo.workers.healthy_worker_ids()
for id_ in original_worker_ids:
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0)

# After building the algorithm, we should have 2 healthy (remote) workers.
self.assertTrue(len(original_worker_ids) == 2)

# Train a bit (and have the envs/workers crash a couple of times).
for _ in range(3):
algo.train()

# After training, each new worker should have been recreated at least once.
new_worker_ids = algo.workers.healthy_worker_ids()
self.assertTrue(len(new_worker_ids) == 2)
for id_ in new_worker_ids:
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] >= 1)
algo.stop()

def test_on_init_and_checkpoint_loaded(self):
config = (
PPOConfig()
Expand Down
37 changes: 18 additions & 19 deletions rllib/evaluation/worker_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
in the returned set as well (default: True). If `num_workers`
is 0, always create a local worker.
logdir: Optional logging directory for workers.
_setup: Whether to setup workers. This is only for testing.
_setup: Whether to actually set up workers. This is only for testing.
"""
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

Expand Down Expand Up @@ -635,9 +635,9 @@ def foreach_worker(
self,
func: Callable[[RolloutWorker], T],
*,
local_worker=True,
local_worker: bool = True,
# TODO(jungong) : switch to True once Algorithm is migrated.
healthy_only=False,
healthy_only: bool = False,
remote_worker_ids: List[int] = None,
timeout_seconds: Optional[int] = None,
return_obj_refs: bool = False,
Expand All @@ -647,10 +647,9 @@ def foreach_worker(
Args:
func: The function to call for each worker (as only arg).
local_worker: Whether apply func on local worker too. Default is True.
healthy_only: Apply func on known active workers only. By default
this will apply func on all workers regardless of their states.
remote_worker_ids: Apply func on a selected set of remote workers.
local_worker: Whether apply `func` on local worker too. Default is True.
healthy_only: Apply `func` on known-to-be healthy workers only.
remote_worker_ids: Apply `func` on a selected set of remote workers.
timeout_seconds: Time to wait for results. Default is None.
return_obj_refs: whether to return ObjectRef instead of actual results.
Note, for fault tolerance reasons, these returned ObjectRefs should
Expand Down Expand Up @@ -689,20 +688,19 @@ def foreach_worker_with_id(
self,
func: Callable[[int, RolloutWorker], T],
*,
local_worker=True,
local_worker: bool = True,
# TODO(jungong) : switch to True once Algorithm is migrated.
healthy_only=False,
healthy_only: bool = False,
remote_worker_ids: List[int] = None,
timeout_seconds: Optional[int] = None,
) -> List[T]:
"""Similar to foreach_worker(), but calls the function with id of the worker too.
Args:
func: The function to call for each worker (as only arg).
local_worker: Whether apply func on local worker too. Default is True.
healthy_only: Apply func on known active workers only. By default
this will apply func on all workers regardless of their states.
remote_worker_ids: Apply func on a selected set of remote workers.
local_worker: Whether apply `func` on local worker too. Default is True.
healthy_only: Apply `func` on known-to-be healthy workers only.
remote_worker_ids: Apply `func` on a selected set of remote workers.
timeout_seconds: Time to wait for results. Default is None.
Returns:
Expand Down Expand Up @@ -736,7 +734,7 @@ def foreach_worker_async(
func: Callable[[RolloutWorker], T],
*,
# TODO(jungong) : switch to True once Algorithm is migrated.
healthy_only=False,
healthy_only: bool = False,
remote_worker_ids: List[int] = None,
) -> int:
"""Calls the given function asynchronously with each worker as the argument.
Expand All @@ -747,9 +745,8 @@ def foreach_worker_async(
Args:
func: The function to call for each worker (as only arg).
healthy_only: Apply func on known active workers only. By default
this will apply func on all workers regardless of their states.
remote_worker_ids: Apply func on a selected set of remote workers.
healthy_only: Apply `func` on known-to-be healthy workers only.
remote_worker_ids: Apply `func` on a selected set of remote workers.
Returns:
The number of async requests that are currently in-flight.
Expand All @@ -773,6 +770,7 @@ def fetch_ready_async_reqs(
Args:
timeout_seconds: Time to wait for results. Default is 0, meaning
those requests that are already ready.
return_obj_refs: Whether to return ObjectRef instead of actual results.
mark_healthy: Whether to mark the worker as healthy based on call results.
Returns:
Expand Down Expand Up @@ -888,15 +886,16 @@ def foreach_env_with_context(

@DeveloperAPI
def probe_unhealthy_workers(self) -> List[int]:
"""Checks the unhealth workers, and try restoring their states.
"""Checks for unhealthy workers and tries restoring their states.
Returns:
IDs of the workers that were restored.
List of IDs of the workers that were restored.
"""
return self.__worker_manager.probe_unhealthy_actors(
timeout_seconds=self._remote_config.worker_health_probe_timeout_s
)

# TODO (sven): Deprecate once ARS/ES have been moved to `rllib_contrib`.
@staticmethod
def _from_existing(
local_worker: RolloutWorker, remote_workers: List[ActorHandle] = None
Expand Down
1 change: 1 addition & 0 deletions rllib/examples/nested_action_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
),
"b": Box(-10.0, 10.0, (2,)),
"c": MultiDiscrete([3, 3]),
"d": Discrete(2),
}
),
},
Expand Down
Loading

0 comments on commit 23eb7f7

Please sign in to comment.