[RLlib] Only sync policy weights on RolloutWorkers of those policies that were actually updated. (ray-project#29973)
sven1977 authored Nov 3, 2022
1 parent 6414c90 commit b38f0c8
Showing 12 changed files with 74 additions and 67 deletions.
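
Every algorithm touched below applies the same pattern: after the local worker has learned, the per-policy results dict returned by training is keyed by exactly the policies that were updated, so the weight broadcast can be restricted to those policies instead of shipping every policy's weights to every RolloutWorker. A minimal illustrative sketch of the new call shape (not part of this commit; `workers`, `train_results`, and `global_vars` follow the names used in the diffs below):

from typing import Any, Dict


def sync_only_trained_policies(
    workers, train_results: Dict[str, Any], global_vars: Dict[str, Any]
) -> None:
    # The keys of `train_results` are the IDs of the policies that
    # learn_on_batch()/train_one_step() actually updated, so only their
    # weights need to be pushed to the remote RolloutWorkers.
    workers.sync_weights(
        policies=list(train_results.keys()),
        global_vars=global_vars,
    )
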
2 changes: 2 additions & 0 deletions rllib/algorithms/a2c/a2c.py
@@ -242,6 +242,8 @@ def training_step(self) -> ResultDict:
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}
# Synch updated weights back to the workers
# (only those policies that are trainable).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(
policies=self.workers.local_worker().get_policies_to_train(),
7 changes: 5 additions & 2 deletions rllib/algorithms/a3c/a3c.py
@@ -245,9 +245,12 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}

# Synch updated weights back to the particular worker.
# Synch updated weights back to the particular worker
# (only those policies that are trainable).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
weights = local_worker.get_weights(local_worker.get_policies_to_train())
weights = local_worker.get_weights(
policies=local_worker.get_policies_to_train()
)
worker.set_weights.remote(weights, global_vars)

# Update global vars of the local worker.
7 changes: 5 additions & 2 deletions rllib/algorithms/algorithm.py
@@ -1284,12 +1284,15 @@ def training_step(self) -> ResultDict:
train_results = multi_gpu_train_one_step(self, train_batch)

# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
# remote workers (only those policies that were actually trained).
global_vars = {
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
self.workers.sync_weights(
policies=list(train_results.keys()),
global_vars=global_vars,
)

return train_results

36 changes: 21 additions & 15 deletions rllib/algorithms/apex_dqn/apex_dqn.py
@@ -15,7 +15,7 @@
import platform
import random
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import ray
from ray._private.dict import merge_dicts
@@ -43,6 +43,7 @@
from ray.rllib.utils.typing import (
PartialAlgorithmConfigDict,
ResultDict,
SampleBatchType,
)
from ray.tune.trainable import Trainable
from ray.tune.execution.placement_groups import PlacementGroupFactory
@@ -402,18 +403,19 @@ def validate_config(self, config):
@override(DQN)
def training_step(self) -> ResultDict:
num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
worker_samples_collected = defaultdict(int)
num_worker_samples_collected = defaultdict(int)

for worker, samples_infos in num_samples_ready_dict.items():
for samples_info in samples_infos:
self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info["agent_steps"]
self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info["env_steps"]
worker_samples_collected[worker] += samples_info["agent_steps"]
num_worker_samples_collected[worker] += samples_info["agent_steps"]

# update the weights of the workers that returned samples
# only do this if there are remote workers (config["num_workers"] > 1)
# Update the weights of the workers that returned samples.
# Only do this if there are remote workers (config["num_workers"] > 1).
# Also, only update those policies that were actually trained.
if self.workers.remote_workers():
self.update_workers(worker_samples_collected)
self.update_workers(num_worker_samples_collected)

# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
@@ -424,7 +426,7 @@ def training_step(self) -> ResultDict:
# trigger a sample from the replay actors and enqueue operation to the
# learner thread.
self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
worker_samples_collected
num_worker_samples_collected
)
self.update_replay_sample_priority()

@@ -488,9 +490,12 @@ def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int:
max_steps_weight_sync_delay = self.config["optimizer"]["max_weight_sync_delay"]
# Update our local copy of the weights if the learner thread has updated
# the learner worker's weights
if self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
weights = self.workers.local_worker().get_weights()
policy_ids_updated = self.learner_thread.policy_ids_updated.copy()
self.learner_thread.policy_ids_updated.clear()
if policy_ids_updated:
weights = self.workers.local_worker().get_weights(
policies=policy_ids_updated
)
self.curr_learner_weights = ray.put(weights)

num_workers_updated = 0
@@ -523,21 +528,22 @@ def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int:
return num_workers_updated

def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
self, num_samples_collected: Dict[ActorHandle, int]
self, num_worker_samples_collected: Dict[ActorHandle, int]
) -> None:
"""Get samples from the replay buffer and place them on the learner queue.
Args:
num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
num_worker_samples_collected: A mapping from ActorHandle (RolloutWorker) to
number of samples returned by the remote worker. This is used to
implement training intensity which is the concept of triggering a
certain amount of training based on the number of samples that have
been collected since the last time that training was triggered.
"""

def wait_on_replay_actors() -> None:
def wait_on_replay_actors() -> List[Tuple[ActorHandle, SampleBatchType]]:
"""Wait for the replay actors to finish sampling for timeout seconds.
If the timeout is None, then block on the actors indefinitely.
"""
_replay_samples_ready = self._replay_actor_manager.get_ready()
@@ -547,7 +553,7 @@ def wait_on_replay_actors() -> None:
replay_sample_batches.append((_replay_actor, _sample_batch))
return replay_sample_batches

num_samples_collected = sum(num_samples_collected.values())
num_samples_collected = sum(num_worker_samples_collected.values())
self.curr_num_samples_collected += num_samples_collected
replay_sample_batches = wait_on_replay_actors()
if self.curr_num_samples_collected >= self.config["train_batch_size"]:
@@ -564,7 +570,7 @@ def wait_on_replay_actors() -> None:
)
replay_sample_batches.extend(wait_on_replay_actors())

# add the sample batches to the learner queue
# Add all the tuples of (ActorHandle, SampleBatchType) to the learner queue.
for item in replay_sample_batches:
# Setting block = True prevents the learner thread,
# the main thread, and the gpu loader threads from
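
The docstring of `sample_from_replay_buffer_place_on_learner_queue_non_blocking()` above describes "training intensity": training is triggered based on how many samples have arrived since the last trigger. A tiny hypothetical illustration of that bookkeeping (invented numbers, not the commit's code):

# Suppose train_batch_size = 512 and two workers returned 200 and 350
# agent steps; the real method accumulates these counts across iterations.
train_batch_size = 512
num_worker_samples_collected = {"worker_a": 200, "worker_b": 350}

curr_num_samples_collected = sum(num_worker_samples_collected.values())  # 550
if curr_num_samples_collected >= train_batch_size:
    # Enough fresh experience has accumulated, so the real method would now
    # request batches from the replay actors and enqueue them for the
    # learner thread.
    print("trigger replay sampling / training")
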
3 changes: 2 additions & 1 deletion rllib/algorithms/cql/cql.py
@@ -204,9 +204,10 @@ def training_step(self) -> ResultDict:
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update remote workers's weights after learning on local worker
# (only those policies that were actually trained).
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights()
self.workers.sync_weights(policies=list(train_results.keys()))

# Return all collected metrics for the iteration.
return train_results
4 changes: 2 additions & 2 deletions rllib/algorithms/dqn/learner_thread.py
@@ -30,7 +30,7 @@ def __init__(self, local_worker):
self.grad_timer = _Timer()
self.overall_timer = _Timer()
self.daemon = True
self.weights_updated = False
self.policy_ids_updated = []
self.stopped = False
self.learner_info = {}

@@ -55,6 +55,7 @@ def step(self):
# minibatch SGD, tf vs torch).
learner_info_builder = LearnerInfoBuilder(num_devices=1)
multi_agent_results = self.local_worker.learn_on_batch(ma_batch)
self.policy_ids_updated.extend(list(multi_agent_results.keys()))
for pid, results in multi_agent_results.items():
learner_info_builder.add_learn_on_batch_results(results, pid)
td_error = results["td_error"]
@@ -75,7 +76,6 @@ def step(self):
(replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps())
)
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True
self.overall_timer.push_units_processed(
ma_batch and ma_batch.count or 0
)
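
Taken together with the Ape-X `update_workers()` change above, this learner-thread change replaces the old `weights_updated` boolean with a list handoff: the learner thread appends the IDs of the policies it just trained, and the driver drains that list before fetching and broadcasting weights. A simplified toy sketch of that handoff (the stand-in class and function names are hypothetical; the real `LearnerThread` and `ApexDQN.update_workers()` carry much more state):

import threading
from typing import Dict, List


class ToyLearnerThread(threading.Thread):
    """Stand-in for the learner thread: records which policy IDs it updated."""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self.policy_ids_updated: List[str] = []

    def on_learn_results(self, multi_agent_results: Dict[str, dict]) -> None:
        # learn_on_batch() returns a dict keyed by the trained policy IDs.
        self.policy_ids_updated.extend(multi_agent_results.keys())


def broadcast_if_updated(learner: ToyLearnerThread, local_worker, remote_workers) -> None:
    # Driver side: drain the list, then ship weights only for those policies.
    updated = learner.policy_ids_updated.copy()
    learner.policy_ids_updated.clear()
    if not updated:
        return
    weights = local_worker.get_weights(policies=updated)
    for worker in remote_workers:
        worker.set_weights.remote(weights)
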
56 changes: 20 additions & 36 deletions rllib/algorithms/impala/impala.py
@@ -11,10 +11,6 @@
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.execution.common import (
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
@@ -45,6 +41,7 @@
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
PolicyID,
ResultDict,
SampleBatchType,
T,
@@ -408,33 +405,6 @@ def gather_experiences_directly(workers, config):
return train_batches


# Update worker weights as they finish generating experiences.
class BroadcastUpdateLearnerWeights:
def __init__(self, learner_thread, workers, broadcast_interval):
self.learner_thread = learner_thread
self.steps_since_broadcast = 0
self.broadcast_interval = broadcast_interval
self.workers = workers
self.weights = workers.local_worker().get_weights()

def __call__(self, item):
actor, batch = item
self.steps_since_broadcast += 1
if (
self.steps_since_broadcast >= self.broadcast_interval
and self.learner_thread.weights_updated
):
self.weights = ray.put(self.workers.local_worker().get_weights())
self.steps_since_broadcast = 0
self.learner_thread.weights_updated = False
# Update metrics.
metrics = _get_shared_metrics()
metrics.counters["num_weight_broadcasts"] += 1
actor.set_weights.remote(self.weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())


class Impala(Algorithm):
"""Importance weighted actor/learner architecture (IMPALA) Algorithm
@@ -641,9 +611,9 @@ def training_step(self) -> ResultDict:
# Extract most recent train results from learner thread.
train_results = self.process_trained_results()

# Sync worker weights.
# Sync worker weights (only those policies that were actually updated).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.update_workers_if_necessary()
self.update_workers_if_necessary(policy_ids=list(train_results.keys()))

return train_results

@@ -840,7 +810,21 @@ def process_experiences_tree_aggregation(

return ready_processed_batches

def update_workers_if_necessary(self) -> None:
def update_workers_if_necessary(
self,
policy_ids: Optional[List[PolicyID]] = None,
) -> None:
"""Updates all RolloutWorkers that require updating.
Updates only if NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS has been
reached and the worker has sent samples in this iteration. Also only updates
those policies, whose IDs are given via `policies` (if None, update all
policies).
Args:
policy_ids: Optional list of Policy IDs to update. If None, will update all
policies on the to-be-updated workers.
"""
# Only need to update workers if there are remote workers.
global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
@@ -850,9 +834,9 @@ def update_workers_if_necessary(self) -> None:
>= self.config["broadcast_interval"]
and self.workers_that_need_updates
):
weights = ray.put(self.workers.local_worker().get_weights())
weights = ray.put(self.workers.local_worker().get_weights(policy_ids))
self._learner_thread.policy_ids_updated.clear()
self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
self._learner_thread.weights_updated = False
self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1

for worker in self.workers_that_need_updates:
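
IMPALA's sync differs from the plain `sync_weights()` call used by the other algorithms in this commit: the broadcast is gated on `broadcast_interval`, the (now policy-filtered) weights are placed into the Ray object store once, and the same object reference is pushed to every worker that contributed samples since the last sync. A rough sketch of just that broadcast step, assuming `workers_that_need_updates` is a set of remote RolloutWorker actor handles (simplified; the counter bookkeeping shown in the diff above is omitted, and the tail of the loop is inferred from similar code elsewhere in this commit):

from typing import Dict, List, Optional

import ray


def broadcast_updated_weights(
    local_worker,
    workers_that_need_updates: set,
    policy_ids: Optional[List[str]] = None,
    global_vars: Optional[Dict] = None,
) -> None:
    # Serialize the filtered weights into the object store exactly once ...
    weights_ref = ray.put(local_worker.get_weights(policy_ids))
    # ... then hand the same reference to each worker that sent samples,
    # mirroring the worker.set_weights.remote(...) usage in a3c.py/apex_dqn.py.
    for worker in workers_that_need_updates:
        worker.set_weights.remote(weights_ref, global_vars)
    workers_that_need_updates.clear()
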
6 changes: 4 additions & 2 deletions rllib/algorithms/marwil/marwil.py
@@ -260,10 +260,12 @@ def training_step(self) -> ResultDict:
}

# Update weights - after learning on the local worker - on all remote
# workers.
# workers (only those policies that were actually trained).
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
self.workers.sync_weights(
policies=list(train_results.keys()), global_vars=global_vars
)

# Update global vars on local worker as well.
self.workers.local_worker().set_global_vars(global_vars)
7 changes: 5 additions & 2 deletions rllib/algorithms/ppo/ppo.py
@@ -434,9 +434,12 @@ def training_step(self) -> ResultDict:
# workers.
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
self.workers.sync_weights(
policies=list(train_results.keys()),
global_vars=global_vars,
)

# For each policy: update KL scale and warn about possible issues
# For each policy: Update KL scale and warn about possible issues
for policy_id, policy_info in train_results.items():
# Update KL loss with dynamic scaling
# for each (possibly multiagent) policy we are training
7 changes: 5 additions & 2 deletions rllib/algorithms/simple_q/simple_q.py
@@ -385,9 +385,12 @@ def training_step(self) -> ResultDict:
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
# on all remote workers (only those policies that were actually trained).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
self.workers.sync_weights(
policies=list(train_results.keys()),
global_vars=global_vars,
)
else:
train_results = {}

4 changes: 2 additions & 2 deletions rllib/execution/learner_thread.py
@@ -61,7 +61,7 @@ def __init__(
self.load_timer = _Timer()
self.load_wait_timer = _Timer()
self.daemon = True
self.weights_updated = False
self.policy_ids_updated = []
self.learner_info = {}
self.stopped = False
self.num_steps = 0
@@ -87,10 +87,10 @@ def step(self) -> Optional[_NextValueNotReady]:
# tf vs torch).
learner_info_builder = LearnerInfoBuilder(num_devices=1)
multi_agent_results = self.local_worker.learn_on_batch(batch)
self.policy_ids_updated.extend(list(multi_agent_results.keys()))
for pid, results in multi_agent_results.items():
learner_info_builder.add_learn_on_batch_results(results, pid)
self.learner_info = learner_info_builder.finalize()
self.weights_updated = True

self.num_steps += 1
# Put tuple: env-steps, agent-steps, and learner info into the queue.
2 changes: 1 addition & 1 deletion rllib/execution/multi_gpu_learner_thread.py
@@ -163,7 +163,7 @@ def step(self) -> None:
offset=0, buffer_index=buffer_idx
)
learner_info_builder.add_learn_on_batch_results(default_policy_results)
self.weights_updated = True
self.policy_ids_updated.append(pid)
get_num_samples_loaded_into_buffer += (
policy.get_num_samples_loaded_into_buffer(buffer_idx)
)
