Allowing PPO to handle async sampling #34

Merged · 9 commits · Sep 8, 2020
2 changes: 2 additions & 0 deletions ci/azure_pipelines/main.yml
@@ -8,6 +8,7 @@
name: $(BuildDefinitionName)_$(SourceBranchName)_$(BuildID)
stages:
- stage: Build
dependsOn: []
jobs:
- job: RayTests
timeoutInMinutes: 240
@@ -297,6 +298,7 @@ stages:
# Template containing steps to publish artifacts
- template: templates/artifacts.yml
- stage: Tests
dependsOn: []
jobs:
- job: StreamingTests
# Disabled
2 changes: 1 addition & 1 deletion ci/azure_pipelines/templates/info.yml
@@ -14,7 +14,7 @@ steps:
echo "Please check the changes, change the azure pipelines acordingly and update the sha256"
exit 1
fi
EXPECTED_HASH_CI_FOLDER='aaebf29e0547b0c4a4079e18d9608bbddbb8771d88c70ec3570146a3d2671571'
EXPECTED_HASH_CI_FOLDER='5f5eb0c22adf951eb117e09a60e9e94e336a51a353e9aa776340b5118fc987b4'
CURRENT_HASH_CI_FOLDER=$(find ./ci -path "./ci/azure_pipelines" -prune -o -path "./**/.DS_Store" -prune -o -type f -print0 | sort -z | xargs -0 shasum -a 256 | shasum -a 256 | awk '{print $1}')
if [[ $EXPECTED_HASH_CI_FOLDER != $CURRENT_HASH_CI_FOLDER ]]; then
echo "The original CI folder of the project has changed"
Expand Down
2 changes: 1 addition & 1 deletion ci/travis/test-wheels.sh
@@ -39,7 +39,7 @@ function retry {

if [[ "$platform" == "linux" ]]; then
# Install miniconda.
PY_WHEEL_VERSIONS=("36" "37" "38")
PY_WHEEL_VERSIONS=("cp36" "cp37" "cp38")
PY_MMS=("3.6.9"
"3.7.6"
"3.8.2")
9 changes: 7 additions & 2 deletions python/ray/tests/test_iter.py
@@ -111,7 +111,10 @@ def verify_metrics(x):
it12 = it1.union(it2, deterministic=True)
it123 = it12.union(it3, deterministic=True)
out = it123.for_each(verify_metrics)
assert out.take(20) == [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4]
taken = out.take(20)
expected = [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4]
assert len(taken) == len(expected)
assert taken == expected


def test_from_items(ray_start_regular_shared):
@@ -469,7 +472,9 @@ def gen_slow():
".gather_async()], LocalIterator[ParallelIterator["
"from_iterators[shards=1].for_each()].gather_async()]]]")
results = list(it)
assert all(x[0] == "slow" for x in results[-3:]), results
slow_count = sum(1 for x in results if x[0] == "slow")
assert slow_count >= 1
assert (len(results) - slow_count) >= 8


def test_serialization(ray_start_regular_shared):
21 changes: 18 additions & 3 deletions python/ray/util/iter.py
@@ -1001,7 +1001,9 @@ def union(self,
as many items from the first iterator as the second.
[2, 1, "*"] will cause as many items to be pulled as possible
from the third iterator without blocking. This overrides the
deterministic flag.
deterministic flag. If weights contains fixed values, iteration
stops (StopIteration is raised) if any iterator with a fixed
weight is exhausted before meeting its expected count.
"""

for it in others:
@@ -1036,23 +1038,36 @@
active = list(zip(round_robin_weights, active))

def build_union(timeout=None):
MAX_PULL = 100 # TODO(ekl) how to best bound this?
pull_counts = [0] * len(active)
while True:
for weight, it in list(active):
yield_counts = [0] * len(active)
for i, (weight, it) in enumerate(list(active)):
if weight == "*":
max_pull = 100 # TODO(ekl) how to best bound this?
max_pull = MAX_PULL
else:
max_pull = _randomized_int_cast(weight)
try:
for _ in range(max_pull):
pull_counts[i] += 1
item = next(it)
if isinstance(item, _NextValueNotReady):
if timeout is not None:
yield item
break
else:
yield_counts[i] += 1
yield item
except StopIteration:
active.remove((weight, it))
fix_weights = [w != "*" for w in round_robin_weights]
expected_yield_counts = weight if weight != "*" else MAX_PULL
if (any(fix_weights) and
yield_counts[i] < expected_yield_counts and
pull_counts[i] >= MAX_PULL):
raise
if not active:
break

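For reviewers, a minimal usage sketch of the union semantics this hunk changes, assuming the ray.util.iter API on this branch; the items and weights are illustrative, not part of the diff:

```python
import ray
from ray.util.iter import from_items

ray.init()

# Two single-shard parallel iterators of different lengths.
it1 = from_items([1, 2, 3, 4, 5, 6], num_shards=1).gather_sync()
it2 = from_items(["a", "b"], num_shards=1).gather_sync()

# Pull two items from it1 for each item from it2. With the change above,
# because both weights are fixed, the union no longer silently keeps
# draining it1 once it2 is exhausted before meeting its expected count;
# it stops the iteration instead.
mixed = it1.union(it2, round_robin_weights=[2, 1])
for item in mixed:
    print(item)
```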
9 changes: 9 additions & 0 deletions rllib/agents/ddpg/apex.py
@@ -24,6 +24,15 @@
"timesteps_per_iteration": 25000,
"worker_side_prioritization": True,
"min_iter_time_s": 30,
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed as fast as possible.
"training_intensity": None,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "async",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": 2,
},
)

6 changes: 6 additions & 0 deletions rllib/agents/ddpg/ddpg.py
@@ -148,6 +148,12 @@
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "bulk_sync",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": None,

# Deprecated keys.
"parameter_noise": DEPRECATED_VALUE,
13 changes: 12 additions & 1 deletion rllib/agents/dqn/apex.py
@@ -43,6 +43,12 @@
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed as fast as possible.
"training_intensity": None,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "async",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": 2,
},
)
# __sphinx_doc_end__
@@ -107,7 +113,12 @@ def update_prio_and_stats(item: ("ActorHandle", dict, int)):
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our replay buffer actors. Update
# the weights of the worker that generated the batch.
rollouts = ParallelRollouts(workers, mode="async", num_async=2)
parallel_rollouts_mode = config.get("parallel_rollouts_mode", "async")
num_async = config.get("parallel_rollouts_num_async")
# The key may exist but be explicitly set to None, so fall back here.
if num_async is None:
num_async = 2
rollouts = ParallelRollouts(
workers, mode=parallel_rollouts_mode, num_async=num_async)
store_op = rollouts \
.for_each(StoreToReplayBuffer(actors=replay_actors))
if config.get("execution_plan_custom_store_ops"):
13 changes: 12 additions & 1 deletion rllib/agents/dqn/dqn.py
@@ -125,6 +125,12 @@
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "bulk_sync",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": None,

# DEPRECATED VALUES (set to -1 to indicate they have not been overwritten
# by user's config). If we don't set them here, we will get an error
@@ -250,7 +256,12 @@ def execution_plan(workers, config):
multiagent_sync_replay=config.get("multiagent_sync_replay"),
**prio_args)

rollouts = ParallelRollouts(workers, mode="bulk_sync")
parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
num_async = config.get("parallel_rollouts_num_async")
# The key may exist but be explicitly set to None, so fall back here.
if num_async is None:
num_async = 1
rollouts = ParallelRollouts(
workers, mode=parallel_rollouts_mode, num_async=num_async)

# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our local replay buffer. Calling
13 changes: 12 additions & 1 deletion rllib/agents/ppo/ppo.py
@@ -77,6 +77,12 @@
# Whether to fake GPUs (using CPUs).
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
"_fake_gpus": False,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "bulk_sync",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": None,
})
# __sphinx_doc_end__
# yapf: enable
@@ -173,7 +179,12 @@ def update(pi, pi_id):


def execution_plan(workers, config):
rollouts = ParallelRollouts(workers, mode="bulk_sync")
parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
num_async = config.get("parallel_rollouts_num_async")
# The key may exist but be explicitly set to None, so fall back here.
if num_async is None:
num_async = 1
rollouts = ParallelRollouts(
workers, mode=parallel_rollouts_mode, num_async=num_async)

if config.get("execution_plan_custom_store_ops"):
custom_store_ops = config["execution_plan_custom_store_ops"]
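To make the new keys concrete: a hedged sketch of switching PPO to async sampling via Tune. The env, stop criteria, and worker count are illustrative; only the two `parallel_rollouts_*` keys come from this PR:

```python
import ray
from ray import tune

ray.init()

tune.run(
    "PPO",
    stop={"training_iteration": 10},
    config={
        "env": "CartPole-v0",
        "num_workers": 4,
        # Keys added in this PR: sample asynchronously, with at most two
        # in-flight sample requests per rollout worker.
        "parallel_rollouts_mode": "async",
        "parallel_rollouts_num_async": 2,
    },
)
```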
6 changes: 6 additions & 0 deletions rllib/agents/sac/apex.py
@@ -28,6 +28,12 @@
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed as fast as possible.
"training_intensity": None,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "async",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": 2,
},
)

6 changes: 6 additions & 0 deletions rllib/agents/sac/sac.py
@@ -107,6 +107,12 @@
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span.
"min_iter_time_s": 1,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "bulk_sync",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": None,

# Whether the loss should be calculated deterministically (w/o the
# stochastic action sampling step). True only useful for cont. actions and
9 changes: 9 additions & 0 deletions rllib/agents/trainer.py
@@ -275,6 +275,15 @@
"extra_python_environs_for_driver": {},
# The extra python environments need to set for worker processes.
"extra_python_environs_for_worker": {},
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed as fast as possible.
"training_intensity": None,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": None,
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": None,

# === Advanced Resource Settings ===
# Number of CPUs to allocate per worker.
6 changes: 6 additions & 0 deletions rllib/contrib/maddpg/maddpg.py
@@ -106,6 +106,12 @@
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 0,
# Which mode to use in the ParallelRollouts operator used to collect
# samples. For more details check the operator in rollout_ops module.
"parallel_rollouts_mode": "bulk_sync",
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": None,
})
# __sphinx_doc_end__
# yapf: enable
18 changes: 13 additions & 5 deletions rllib/evaluation/rollout_worker.py
@@ -258,11 +258,7 @@ def __init__(self,
for key, value in extra_python_environs.items():
os.environ[key] = str(value)

def gen_rollouts():
while True:
yield self.sample()

ParallelIteratorWorker.__init__(self, gen_rollouts, False)
ParallelIteratorWorker.__init__(self, self.gen_rollouts, False)

policy_config = policy_config or {}
if (tf and policy_config.get("framework") == "tfe"
@@ -503,6 +499,18 @@ def make_env(vector_index):
"Created rollout worker with env {} ({}), policies {}".format(
self.async_env, self.env, self.policy_map))

@DeveloperAPI
def gen_rollouts(self):
"""Simple generator of rollouts.
This generator is used by the ParallelRollout operators to produce
samples using the Ray ParallelIterator API.

Child classes could override this method if a custom generator function
is required.
"""
while True:
yield self.sample()

@DeveloperAPI
def sample(self):
"""Returns a batch of experience sampled from this worker.
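Since gen_rollouts is now a regular @DeveloperAPI method rather than a closure, subclasses can override it. A minimal sketch of such an override; the per-batch bookkeeping shown is hypothetical, not part of this diff:

```python
from ray.rllib.evaluation.rollout_worker import RolloutWorker


class CountingRolloutWorker(RolloutWorker):
    """Hypothetical worker that counts the batches it produces."""

    def gen_rollouts(self):
        produced = 0
        while True:
            batch = self.sample()
            produced += 1
            # Any per-batch bookkeeping or filtering could happen here
            # before the batch enters the ParallelIterator pipeline.
            yield batch
```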