[RLlib] IMPALA/APPO multi-agent mix-in-buffer fixes (plus MA learning tests). (ray-project#25848)
ArturNiederfahrenhorst authored Jun 17, 2022
1 parent 1c27469 commit a322cc5
Showing 12 changed files with 257 additions and 130 deletions.
20 changes: 20 additions & 0 deletions rllib/BUILD
@@ -195,6 +195,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/appo"]
)

py_test(
name = "learning_tests_multi_agent_cartpole_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi-agent-cartpole-appo.yaml"],
args = ["--yaml-dir=tuned_examples/appo"]
)

# py_test(
# name = "learning_tests_frozenlake_appo",
# main = "tests/run_regression_tests.py",
@@ -402,6 +412,16 @@ py_test(
# args = ["--yaml-dir=tuned_examples/impala"]
# )

py_test(
name = "learning_tests_multi_agent_cartpole_impala",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/impala/multi-agent-cartpole-impala.yaml"],
args = ["--yaml-dir=tuned_examples/impala"]
)

py_test(
name = "learning_tests_cartpole_impala_fake_gpus",
main = "tests/run_regression_tests.py",
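
For context: the two new BUILD targets above (learning_tests_multi_agent_cartpole_appo and learning_tests_multi_agent_cartpole_impala) drive multi-agent CartPole regression runs through tests/run_regression_tests.py. The referenced yaml files are not part of the hunks shown here; the sketch below is a hypothetical Python equivalent of such a tuned example, with every value (number of agents, workers, stop criterion, policy setup) assumed for illustration rather than taken from the actual config.

    import ray
    from ray import tune
    from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

    ray.init()
    tune.run(
        "APPO",  # the IMPALA target is analogous; pass "IMPALA" here instead
        stop={"episode_reward_mean": 600.0},  # assumed stopping criterion
        config={
            "env": MultiAgentCartPole,
            "env_config": {"num_agents": 4},
            "num_workers": 4,
            "multiagent": {
                # Four independent policies, one per agent id (illustrative;
                # the real tuned example may share a single policy instead).
                "policies": {"p0", "p1", "p2", "p3"},
                "policy_mapping_fn": (
                    lambda agent_id, episode, worker, **kw: f"p{agent_id}"
                ),
            },
        },
    )

Locally, such a target would typically be run with something like `bazel test //rllib:learning_tests_multi_agent_cartpole_appo`.
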
15 changes: 9 additions & 6 deletions rllib/algorithms/apex_ddpg/apex_ddpg.py
@@ -6,10 +6,12 @@
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
)
from ray.util.iter import LocalIterator


@@ -196,8 +198,9 @@ def on_worker_failures(
removed_workers: removed worker ids.
new_workers: ids of newly created workers.
"""
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
if self.config["_disable_execution_plan_api"]:
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)

@staticmethod
@override(DDPG)
5 changes: 3 additions & 2 deletions rllib/algorithms/apex_dqn/apex_dqn.py
@@ -652,8 +652,9 @@ def on_worker_failures(
removed_workers: removed worker ids.
new_workers: ids of newly created workers.
"""
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
if self.config["_disable_execution_plan_api"]:
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)

@override(Algorithm)
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
48 changes: 2 additions & 46 deletions rllib/algorithms/appo/appo_tf_policy.py
@@ -11,6 +11,7 @@
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.algorithms.appo.utils import make_appo_models
from ray.rllib.algorithms.impala import vtrace_tf as vtrace
from ray.rllib.algorithms.impala.impala_tf_policy import (
_make_time_major,
@@ -32,11 +33,9 @@
KLCoeffMixin,
ValueNetworkMixin,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.utils.annotations import (
DeveloperAPI,
override,
)
from ray.rllib.utils.framework import try_import_tf
@@ -45,52 +44,9 @@

tf1, tf, tfv = try_import_tf()

POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"

logger = logging.getLogger(__name__)


@DeveloperAPI
def make_appo_model(policy) -> ModelV2:
"""Builds model and target model for APPO.
Returns:
ModelV2: The Model for the Policy to use.
Note: The target model will not be returned, just assigned to
`policy.target_model`.
"""
# Get the num_outputs for the following model construction calls.
_, logit_dim = ModelCatalog.get_action_dist(
policy.action_space, policy.config["model"]
)

# Construct the (main) model.
policy.model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=POLICY_SCOPE,
framework=policy.framework,
)
policy.model_variables = policy.model.variables()

# Construct the target model.
policy.target_model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=TARGET_POLICY_SCOPE,
framework=policy.framework,
)
policy.target_model_variables = policy.target_model.variables()

# Return only the model (not the target model).
return policy.model


class TargetNetworkMixin:
"""Target NN is updated by master learner via the `update_target` method.
@@ -182,7 +138,7 @@ def __init__(

@override(base)
def make_model(self) -> ModelV2:
return make_appo_model(self)
return make_appo_models(self)

@override(base)
def loss(
4 changes: 2 additions & 2 deletions rllib/algorithms/appo/appo_torch_policy.py
@@ -11,7 +11,7 @@
from typing import Any, Dict, List, Optional, Type, Union

import ray
from ray.rllib.algorithms.appo.appo_tf_policy import make_appo_model
from ray.rllib.algorithms.appo.utils import make_appo_models
import ray.rllib.algorithms.impala.vtrace_torch as vtrace
from ray.rllib.algorithms.impala.impala_torch_policy import (
make_time_major,
@@ -101,7 +101,7 @@ def init_view_requirements(self):

@override(TorchPolicyV2)
def make_model(self) -> ModelV2:
return make_appo_model(self)
return make_appo_models(self)

@override(TorchPolicyV2)
def loss(
45 changes: 45 additions & 0 deletions rllib/algorithms/appo/utils.py
@@ -0,0 +1,45 @@
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2


POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"


def make_appo_models(policy) -> ModelV2:
"""Builds model and target model for APPO.
Returns:
ModelV2: The Model for the Policy to use.
Note: The target model will not be returned, just assigned to
`policy.target_model`.
"""
# Get the num_outputs for the following model construction calls.
_, logit_dim = ModelCatalog.get_action_dist(
policy.action_space, policy.config["model"]
)

# Construct the (main) model.
policy.model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=POLICY_SCOPE,
framework=policy.framework,
)
policy.model_variables = policy.model.variables()

# Construct the target model.
policy.target_model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=TARGET_POLICY_SCOPE,
framework=policy.framework,
)
policy.target_model_variables = policy.target_model.variables()

# Return only the model (not the target model).
return policy.model
55 changes: 27 additions & 28 deletions rllib/algorithms/impala/impala.py
@@ -1,55 +1,54 @@
import copy
import logging
import platform

import queue
from typing import Optional, Type, List, Dict, Union, Callable, Any
from typing import Any, Callable, Dict, List, Optional, Type, Union

import ray
from ray.actor import ActorHandle
from ray.rllib import SampleBatch
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.parallel_requests import (
AsyncRequestsManager,
)
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES

# from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
AlgorithmConfigDict,
SampleBatchType,
T,
)
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.types import ObjectRef

@@ -470,9 +469,7 @@ def get_default_policy_class(
return A3CTorchPolicy
elif config["framework"] == "tf":
if config["vtrace"]:
from ray.rllib.algorithms.impala.impala_tf_policy import (
ImpalaTF1Policy,
)
from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy

return ImpalaTF1Policy
else:
@@ -590,6 +587,7 @@ def setup(self, config: PartialAlgorithmConfigDict):
else 1
),
replay_ratio=self.config["replay_ratio"],
replay_mode=ReplayMode.LOCKSTEP,
)

self._sampling_actor_manager = AsyncRequestsManager(
@@ -658,7 +656,7 @@ def execution_plan(workers, config, **kwargs):
)

def record_steps_trained(item):
count, fetches = item
count, fetches, _ = item
metrics = _get_shared_metrics()
# Manually update the steps trained counter since the learner
# thread is executing outside the pipeline.
@@ -797,7 +795,7 @@ def place_processed_samples_on_learner_queue(self) -> None:

def process_trained_results(self) -> ResultDict:
# Get learner outputs/stats from output queue.
learner_infos = []
learner_info = copy.deepcopy(self._learner_thread.learner_info)
num_env_steps_trained = 0
num_agent_steps_trained = 0

@@ -811,10 +809,9 @@ def process_trained_results(self) -> ResultDict:
num_env_steps_trained += env_steps
num_agent_steps_trained += agent_steps
if learner_results:
learner_infos.append(learner_results)
learner_info.update(learner_results)
else:
raise RuntimeError("The learner thread died in while training")
learner_info = copy.deepcopy(self._learner_thread.learner_info)

# Update the steps trained counters.
self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
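
A side note on the aggregation change above: per-policy learner stats coming off the learner thread's output queue are now merged into a single dict (seeded from a deep copy of the thread's running learner_info) rather than collected in a list. A standalone, hedged illustration of why dict.update() suffices in the multi-agent case (the queue items and policy ids below are invented):

    # Each queue item is assumed to be (env_steps, agent_steps, learner_results).
    drained_queue_items = [
        (200, 400, {"policy_0": {"learner_stats": {"total_loss": 0.51}}}),
        (200, 400, {"policy_1": {"learner_stats": {"total_loss": 0.73}}}),
    ]
    learner_info = {}
    num_env_steps_trained = num_agent_steps_trained = 0
    for env_steps, agent_steps, learner_results in drained_queue_items:
        num_env_steps_trained += env_steps
        num_agent_steps_trained += agent_steps
        if learner_results:
            # One entry per policy id survives, so stats from different
            # policies trained in the same iteration do not overwrite each other.
            learner_info.update(learner_results)
    # learner_info now contains both policy_0 and policy_1 entries.
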
@@ -839,7 +836,7 @@ def process_experiences_directly(
for batch in batches:
batch = batch.decompress_if_needed()
self.local_mixin_buffer.add_batch(batch)
batch = self.local_mixin_buffer.replay()
batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
if batch:
processed_batches.append(batch)
return processed_batches
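
The hunk above, together with the LOCKSTEP flag added in setup() further up and the matching change in the aggregator worker below, is the core of the mix-in-buffer fix: the local buffer now runs in lockstep replay mode and is replayed for all policies at once, so multi-agent batches come back as whole MultiAgentBatches instead of being replayed per policy. A minimal sketch of that usage, assuming the constructor and method signatures implied by these call sites (capacity, replay_ratio, and the sample data are placeholders):

    from ray.rllib.execution.buffers.mixin_replay_buffer import (
        MixInMultiAgentReplayBuffer,
    )
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
    from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES

    buffer = MixInMultiAgentReplayBuffer(
        capacity=1000,                    # placeholder capacity
        replay_ratio=0.5,                 # mix ~50% replayed samples into the output
        replay_mode=ReplayMode.LOCKSTEP,  # keep multi-agent batches together
    )

    # Placeholder rollout data; in IMPALA/APPO this comes from the sampling workers.
    batch = SampleBatch({"obs": [[0.0, 0.0, 0.0, 0.0]], "actions": [0], "rewards": [1.0]})
    buffer.add_batch(batch)
    mixed = buffer.replay(_ALL_POLICIES)  # one call returns samples covering all policies
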
@@ -898,8 +895,9 @@ def on_worker_failures(
removed_workers: removed worker ids.
new_workers: ids of newly created workers.
"""
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
if self.config["_disable_execution_plan_api"]:
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)

@override(Algorithm)
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
@@ -925,12 +923,13 @@ def __init__(self, config: AlgorithmConfigDict):
else 1
),
replay_ratio=self.config["replay_ratio"],
replay_mode=ReplayMode.LOCKSTEP,
)

def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
batch = batch.decompress_if_needed()
self._mixin_buffer.add_batch(batch)
processed_batches = self._mixin_buffer.replay()
processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
return processed_batches

def apply(
(Diff truncated: the remaining 5 changed files are not shown here.)
