From b38f0c8454e68d00a3a029c4cefa6b87ea3a8c1d Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 3 Nov 2022 18:07:31 +0100 Subject: [PATCH] [RLlib] Only sync policy weights on RolloutWorkers of those policies that were actually updated. (#29973) --- rllib/algorithms/a2c/a2c.py | 2 + rllib/algorithms/a3c/a3c.py | 7 ++- rllib/algorithms/algorithm.py | 7 ++- rllib/algorithms/apex_dqn/apex_dqn.py | 36 +++++++------ rllib/algorithms/cql/cql.py | 3 +- rllib/algorithms/dqn/learner_thread.py | 4 +- rllib/algorithms/impala/impala.py | 56 ++++++++------------- rllib/algorithms/marwil/marwil.py | 6 ++- rllib/algorithms/ppo/ppo.py | 7 ++- rllib/algorithms/simple_q/simple_q.py | 7 ++- rllib/execution/learner_thread.py | 4 +- rllib/execution/multi_gpu_learner_thread.py | 2 +- 12 files changed, 74 insertions(+), 67 deletions(-) diff --git a/rllib/algorithms/a2c/a2c.py b/rllib/algorithms/a2c/a2c.py index b484a2cf9b669..429a26c9963da 100644 --- a/rllib/algorithms/a2c/a2c.py +++ b/rllib/algorithms/a2c/a2c.py @@ -242,6 +242,8 @@ def training_step(self) -> ResultDict: global_vars = { "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], } + # Synch updated weights back to the workers + # (only those policies that are trainable). with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights( policies=self.workers.local_worker().get_policies_to_train(), diff --git a/rllib/algorithms/a3c/a3c.py b/rllib/algorithms/a3c/a3c.py index b3acdc298ddb8..39e399af47929 100644 --- a/rllib/algorithms/a3c/a3c.py +++ b/rllib/algorithms/a3c/a3c.py @@ -245,9 +245,12 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]: "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], } - # Synch updated weights back to the particular worker. + # Synch updated weights back to the particular worker + # (only those policies that are trainable). with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: - weights = local_worker.get_weights(local_worker.get_policies_to_train()) + weights = local_worker.get_weights( + policies=local_worker.get_policies_to_train() + ) worker.set_weights.remote(weights, global_vars) # Update global vars of the local worker. diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index d775ba9d1f061..ee71fdf82c578 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -1284,12 +1284,15 @@ def training_step(self) -> ResultDict: train_results = multi_gpu_train_one_step(self, train_batch) # Update weights and global_vars - after learning on the local worker - on all - # remote workers. + # remote workers (only those policies that were actually trained). 
global_vars = { "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], } with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: - self.workers.sync_weights(global_vars=global_vars) + self.workers.sync_weights( + policies=list(train_results.keys()), + global_vars=global_vars, + ) return train_results diff --git a/rllib/algorithms/apex_dqn/apex_dqn.py b/rllib/algorithms/apex_dqn/apex_dqn.py index b3ed83e9b9c9f..9ed42fb34e32b 100644 --- a/rllib/algorithms/apex_dqn/apex_dqn.py +++ b/rllib/algorithms/apex_dqn/apex_dqn.py @@ -15,7 +15,7 @@ import platform import random from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import ray from ray._private.dict import merge_dicts @@ -43,6 +43,7 @@ from ray.rllib.utils.typing import ( PartialAlgorithmConfigDict, ResultDict, + SampleBatchType, ) from ray.tune.trainable import Trainable from ray.tune.execution.placement_groups import PlacementGroupFactory @@ -402,18 +403,19 @@ def validate_config(self, config): @override(DQN) def training_step(self) -> ResultDict: num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers() - worker_samples_collected = defaultdict(int) + num_worker_samples_collected = defaultdict(int) for worker, samples_infos in num_samples_ready_dict.items(): for samples_info in samples_infos: self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info["agent_steps"] self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info["env_steps"] - worker_samples_collected[worker] += samples_info["agent_steps"] + num_worker_samples_collected[worker] += samples_info["agent_steps"] - # update the weights of the workers that returned samples - # only do this if there are remote workers (config["num_workers"] > 1) + # Update the weights of the workers that returned samples. + # Only do this if there are remote workers (config["num_workers"] > 1). + # Also, only update those policies that were actually trained. if self.workers.remote_workers(): - self.update_workers(worker_samples_collected) + self.update_workers(num_worker_samples_collected) # Update target network every `target_network_update_freq` sample steps. cur_ts = self._counters[ @@ -424,7 +426,7 @@ def training_step(self) -> ResultDict: # trigger a sample from the replay actors and enqueue operation to the # learner thread. 
self.sample_from_replay_buffer_place_on_learner_queue_non_blocking( - worker_samples_collected + num_worker_samples_collected ) self.update_replay_sample_priority() @@ -488,9 +490,12 @@ def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int: max_steps_weight_sync_delay = self.config["optimizer"]["max_weight_sync_delay"] # Update our local copy of the weights if the learner thread has updated # the learner worker's weights - if self.learner_thread.weights_updated: - self.learner_thread.weights_updated = False - weights = self.workers.local_worker().get_weights() + policy_ids_updated = self.learner_thread.policy_ids_updated.copy() + self.learner_thread.policy_ids_updated.clear() + if policy_ids_updated: + weights = self.workers.local_worker().get_weights( + policies=policy_ids_updated + ) self.curr_learner_weights = ray.put(weights) num_workers_updated = 0 @@ -523,12 +528,12 @@ def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int: return num_workers_updated def sample_from_replay_buffer_place_on_learner_queue_non_blocking( - self, num_samples_collected: Dict[ActorHandle, int] + self, num_worker_samples_collected: Dict[ActorHandle, int] ) -> None: """Get samples from the replay buffer and place them on the learner queue. Args: - num_samples_collected: A mapping from ActorHandle (RolloutWorker) to + num_worker_samples_collected: A mapping from ActorHandle (RolloutWorker) to number of samples returned by the remote worker. This is used to implement training intensity which is the concept of triggering a certain amount of training based on the number of samples that have @@ -536,8 +541,9 @@ def sample_from_replay_buffer_place_on_learner_queue_non_blocking( """ - def wait_on_replay_actors() -> None: + def wait_on_replay_actors() -> List[Tuple[ActorHandle, SampleBatchType]]: """Wait for the replay actors to finish sampling for timeout seconds. + If the timeout is None, then block on the actors indefinitely. """ _replay_samples_ready = self._replay_actor_manager.get_ready() @@ -547,7 +553,7 @@ def wait_on_replay_actors() -> None: replay_sample_batches.append((_replay_actor, _sample_batch)) return replay_sample_batches - num_samples_collected = sum(num_samples_collected.values()) + num_samples_collected = sum(num_worker_samples_collected.values()) self.curr_num_samples_collected += num_samples_collected replay_sample_batches = wait_on_replay_actors() if self.curr_num_samples_collected >= self.config["train_batch_size"]: @@ -564,7 +570,7 @@ def wait_on_replay_actors() -> None: ) replay_sample_batches.extend(wait_on_replay_actors()) - # add the sample batches to the learner queue + # Add all the tuples of (ActorHandle, SampleBatchType) to the learner queue. for item in replay_sample_batches: # Setting block = True prevents the learner thread, # the main thread, and the gpu loader threads from diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 70ce34fbdc407..fe354dd7413bf 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -204,9 +204,10 @@ def training_step(self) -> ResultDict: self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update remote workers's weights after learning on local worker + # (only those policies that were actually trained). if self.workers.remote_workers(): with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: - self.workers.sync_weights() + self.workers.sync_weights(policies=list(train_results.keys())) # Return all collected metrics for the iteration. 
return train_results diff --git a/rllib/algorithms/dqn/learner_thread.py b/rllib/algorithms/dqn/learner_thread.py index 168d703b380ef..2fbc8752e1702 100644 --- a/rllib/algorithms/dqn/learner_thread.py +++ b/rllib/algorithms/dqn/learner_thread.py @@ -30,7 +30,7 @@ def __init__(self, local_worker): self.grad_timer = _Timer() self.overall_timer = _Timer() self.daemon = True - self.weights_updated = False + self.policy_ids_updated = [] self.stopped = False self.learner_info = {} @@ -55,6 +55,7 @@ def step(self): # minibatch SGD, tf vs torch). learner_info_builder = LearnerInfoBuilder(num_devices=1) multi_agent_results = self.local_worker.learn_on_batch(ma_batch) + self.policy_ids_updated.extend(list(multi_agent_results.keys())) for pid, results in multi_agent_results.items(): learner_info_builder.add_learn_on_batch_results(results, pid) td_error = results["td_error"] @@ -75,7 +76,6 @@ def step(self): (replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps()) ) self.learner_queue_size.push(self.inqueue.qsize()) - self.weights_updated = True self.overall_timer.push_units_processed( ma_batch and ma_batch.count or 0 ) diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 2373d8812a961..b349d166bd817 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -11,10 +11,6 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer -from ray.rllib.execution.common import ( - _get_global_vars, - _get_shared_metrics, -) from ray.rllib.execution.learner_thread import LearnerThread from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread from ray.rllib.execution.parallel_requests import AsyncRequestsManager @@ -45,6 +41,7 @@ from ray.rllib.utils.typing import ( AlgorithmConfigDict, PartialAlgorithmConfigDict, + PolicyID, ResultDict, SampleBatchType, T, @@ -408,33 +405,6 @@ def gather_experiences_directly(workers, config): return train_batches -# Update worker weights as they finish generating experiences. -class BroadcastUpdateLearnerWeights: - def __init__(self, learner_thread, workers, broadcast_interval): - self.learner_thread = learner_thread - self.steps_since_broadcast = 0 - self.broadcast_interval = broadcast_interval - self.workers = workers - self.weights = workers.local_worker().get_weights() - - def __call__(self, item): - actor, batch = item - self.steps_since_broadcast += 1 - if ( - self.steps_since_broadcast >= self.broadcast_interval - and self.learner_thread.weights_updated - ): - self.weights = ray.put(self.workers.local_worker().get_weights()) - self.steps_since_broadcast = 0 - self.learner_thread.weights_updated = False - # Update metrics. - metrics = _get_shared_metrics() - metrics.counters["num_weight_broadcasts"] += 1 - actor.set_weights.remote(self.weights, _get_global_vars()) - # Also update global vars of the local worker. - self.workers.local_worker().set_global_vars(_get_global_vars()) - - class Impala(Algorithm): """Importance weighted actor/learner architecture (IMPALA) Algorithm @@ -641,9 +611,9 @@ def training_step(self) -> ResultDict: # Extract most recent train results from learner thread. train_results = self.process_trained_results() - # Sync worker weights. + # Sync worker weights (only those policies that were actually updated). 
         with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
-            self.update_workers_if_necessary()
+            self.update_workers_if_necessary(policy_ids=list(train_results.keys()))
 
         return train_results
 
@@ -840,7 +810,21 @@ def process_experiences_tree_aggregation(
 
         return ready_processed_batches
 
-    def update_workers_if_necessary(self) -> None:
+    def update_workers_if_necessary(
+        self,
+        policy_ids: Optional[List[PolicyID]] = None,
+    ) -> None:
+        """Updates all RolloutWorkers that require updating.
+
+        Updates only if NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS has
+        reached the `broadcast_interval` and the worker has sent samples in this
+        iteration. Also only updates those policies whose IDs are given via
+        `policy_ids` (if None, update all policies).
+
+        Args:
+            policy_ids: Optional list of Policy IDs to update. If None, will update all
+                policies on the to-be-updated workers.
+        """
         # Only need to update workers if there are remote workers.
         global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
         self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
@@ -850,9 +834,9 @@ def update_workers_if_necessary(self) -> None:
             >= self.config["broadcast_interval"]
             and self.workers_that_need_updates
         ):
-            weights = ray.put(self.workers.local_worker().get_weights())
+            weights = ray.put(self.workers.local_worker().get_weights(policy_ids))
+            self._learner_thread.policy_ids_updated.clear()
             self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
-            self._learner_thread.weights_updated = False
             self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1
 
             for worker in self.workers_that_need_updates:
diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py
index 342d0b2e20fc4..b314ce02b81e7 100644
--- a/rllib/algorithms/marwil/marwil.py
+++ b/rllib/algorithms/marwil/marwil.py
@@ -260,10 +260,12 @@ def training_step(self) -> ResultDict:
         }
 
         # Update weights - after learning on the local worker - on all remote
-        # workers.
+        # workers (only those policies that were actually trained).
         if self.workers.remote_workers():
             with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
-                self.workers.sync_weights(global_vars=global_vars)
+                self.workers.sync_weights(
+                    policies=list(train_results.keys()), global_vars=global_vars
+                )
 
         # Update global vars on local worker as well.
         self.workers.local_worker().set_global_vars(global_vars)
diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py
index 378c324a9e1d2..987a0c87f33ed 100644
--- a/rllib/algorithms/ppo/ppo.py
+++ b/rllib/algorithms/ppo/ppo.py
@@ -434,9 +434,12 @@ def training_step(self) -> ResultDict:
         # workers.
         if self.workers.remote_workers():
             with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
-                self.workers.sync_weights(global_vars=global_vars)
+                self.workers.sync_weights(
+                    policies=list(train_results.keys()),
+                    global_vars=global_vars,
+                )
 
-        # For each policy: update KL scale and warn about possible issues
+        # For each policy: Update KL scale and warn about possible issues
         for policy_id, policy_info in train_results.items():
             # Update KL loss with dynamic scaling
             # for each (possibly multiagent) policy we are training
diff --git a/rllib/algorithms/simple_q/simple_q.py b/rllib/algorithms/simple_q/simple_q.py
index f224e467f8814..9a33744dc94dd 100644
--- a/rllib/algorithms/simple_q/simple_q.py
+++ b/rllib/algorithms/simple_q/simple_q.py
@@ -385,9 +385,12 @@ def training_step(self) -> ResultDict:
                 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
 
             # Update weights and global_vars - after learning on the local worker -
-            # on all remote workers.
+ # on all remote workers (only those policies that were actually trained). with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: - self.workers.sync_weights(global_vars=global_vars) + self.workers.sync_weights( + policies=list(train_results.keys()), + global_vars=global_vars, + ) else: train_results = {} diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index de79418d88774..fd3c0ea009c8f 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -61,7 +61,7 @@ def __init__( self.load_timer = _Timer() self.load_wait_timer = _Timer() self.daemon = True - self.weights_updated = False + self.policy_ids_updated = [] self.learner_info = {} self.stopped = False self.num_steps = 0 @@ -87,10 +87,10 @@ def step(self) -> Optional[_NextValueNotReady]: # tf vs torch). learner_info_builder = LearnerInfoBuilder(num_devices=1) multi_agent_results = self.local_worker.learn_on_batch(batch) + self.policy_ids_updated.extend(list(multi_agent_results.keys())) for pid, results in multi_agent_results.items(): learner_info_builder.add_learn_on_batch_results(results, pid) self.learner_info = learner_info_builder.finalize() - self.weights_updated = True self.num_steps += 1 # Put tuple: env-steps, agent-steps, and learner info into the queue. diff --git a/rllib/execution/multi_gpu_learner_thread.py b/rllib/execution/multi_gpu_learner_thread.py index 706c76853f98a..5d26d8c4031d3 100644 --- a/rllib/execution/multi_gpu_learner_thread.py +++ b/rllib/execution/multi_gpu_learner_thread.py @@ -163,7 +163,7 @@ def step(self) -> None: offset=0, buffer_index=buffer_idx ) learner_info_builder.add_learn_on_batch_results(default_policy_results) - self.weights_updated = True + self.policy_ids_updated.append(pid) get_num_samples_loaded_into_buffer += ( policy.get_num_samples_loaded_into_buffer(buffer_idx) )
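
Note on the pattern applied across the training_step() implementations above: the
weights broadcast is now restricted to the policies that actually appear in
`train_results` (passed as `policies=list(train_results.keys())` to
`WorkerSet.sync_weights()`, or as `policy_ids` in IMPALA/APEX-DQN). Below is a
minimal, self-contained sketch of that selection step; the helper name
`select_weights_to_sync` and the toy weight dicts are illustrative assumptions,
not RLlib APIs.

from typing import Dict, List, Optional

PolicyID = str
ModelWeights = Dict[str, List[float]]


def select_weights_to_sync(
    all_weights: Dict[PolicyID, ModelWeights],
    train_results: Dict[PolicyID, dict],
    policies: Optional[List[PolicyID]] = None,
) -> Dict[PolicyID, ModelWeights]:
    """Return only the weight sub-dicts of policies that were just updated.

    Passing policies=None falls back to the old behavior (broadcast everything).
    """
    if policies is None:
        policies = list(train_results.keys())
    return {pid: all_weights[pid] for pid in policies if pid in all_weights}


if __name__ == "__main__":
    weights = {
        "policy_a": {"w": [0.1, 0.2]},
        "policy_b": {"w": [0.3, 0.4]},  # e.g. a frozen/league-play opponent
    }
    # Only "policy_a" produced train results in this iteration.
    train_results = {"policy_a": {"learner_stats": {"total_loss": 0.12}}}

    to_sync = select_weights_to_sync(weights, train_results)
    assert list(to_sync) == ["policy_a"]

For multi-agent setups with many frozen or league-play policies, this keeps the
broadcast payload proportional to the number of policies actually being trained
rather than to the size of the whole policy map.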
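
The learner threads switch from a boolean `weights_updated` flag to a
`policy_ids_updated` list, which the driver drains right before it puts the
corresponding weights into the object store (APEX-DQN copies and clears the list
in `update_workers()`; IMPALA clears it in `update_workers_if_necessary()`). The
sketch below is a simplified, hypothetical mock of that hand-off; the
`ToyLearnerThread` class, the lock, and `drain_updated_policy_ids()` are
assumptions for illustration (the patch itself mutates the plain list without a
lock).

import threading
from typing import Dict, List


class ToyLearnerThread(threading.Thread):
    """Stand-in for a learner thread that records which policies it updated."""

    def __init__(self, batches: List[Dict[str, list]]):
        super().__init__(daemon=True)
        self._batches = batches
        self._lock = threading.Lock()
        # Replaces the old boolean `weights_updated` flag.
        self.policy_ids_updated: List[str] = []

    def run(self) -> None:
        for multi_agent_batch in self._batches:
            # Pretend every policy with data in the batch received a gradient update.
            with self._lock:
                self.policy_ids_updated.extend(multi_agent_batch.keys())

    def drain_updated_policy_ids(self) -> List[str]:
        """Driver side: grab and clear the list of recently updated policy IDs."""
        with self._lock:
            updated, self.policy_ids_updated = self.policy_ids_updated, []
        return sorted(set(updated))


if __name__ == "__main__":
    t = ToyLearnerThread([{"policy_a": []}, {"policy_a": [], "policy_b": []}])
    t.start()
    t.join()
    # Only these policies' weights would need to be serialized and pushed
    # to the rollout workers on the next broadcast.
    print(t.drain_updated_policy_ids())  # -> ['policy_a', 'policy_b']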