[RLlib] Make Dataset reader default reader and enable CRR to use dataset #26304

Merged
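This PR makes RLlib's Dataset-backed input reader the default offline reader and switches CRR (and CQL) to consume offline data through it directly rather than staging it in a local replay buffer. A rough sketch of the kind of offline run this enables is below; treat it as illustrative only, since the exact config API and input_config keys are assumptions inferred from the commits listed under it, not code taken from this PR.

# Illustrative sketch, not code from this PR. Assumes the CRRConfig class
# exported by ray.rllib.algorithms.crr and an input_config taking "format" and
# "paths" keys (per the "Change path to paths" / "Add support for file
# directories" commits below); the data path is a placeholder.
from ray.rllib.algorithms.crr import CRRConfig

config = (
    CRRConfig()
    .environment(env="Pendulum-v1")
    .framework("torch")
    .offline_data(
        input_="dataset",
        # "paths" may point at a single JSON file or a directory of files.
        input_config={"format": "json", "paths": "/tmp/pendulum-out"},
    )
)
algo = config.build()
results = algo.train()  # one training iteration over the offline data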
Commits (53, changes shown from all commits)
9000822
Temp
avnishn Jun 23, 2022
d4f58be
Lint
avnishn Jun 23, 2022
4609ba3
Make the dataset and json readers batchable
avnishn Jun 24, 2022
7d3f3f3
Make batch size respective to num workers and ad __main__ to tests
avnishn Jun 24, 2022
1c2df64
Merge branch 'batchable_input_reader' of https://github.com/avnishn/r…
avnishn Jun 24, 2022
dccdf22
Temp
avnishn Jun 27, 2022
a827949
Temp
avnishn Jun 27, 2022
9899634
Make batch size correct
avnishn Jun 27, 2022
49fd38d
Merge branch 'batchable_input_reader' of https://github.com/avnishn/r…
avnishn Jun 27, 2022
f60321a
Address Review Comments
avnishn Jun 27, 2022
4e821c8
Address Review Comments
avnishn Jun 27, 2022
bd893fa
Merge branch 'master' of https://github.com/ray-project/ray into batc…
avnishn Jun 27, 2022
5820345
CRR works with dataset readers
avnishn Jun 27, 2022
d93a086
Lint
avnishn Jun 27, 2022
c805da3
Merge branch 'batchable_input_reader' of https://github.com/avnishn/r…
avnishn Jun 27, 2022
976ad16
Fix bug
avnishn Jun 27, 2022
0b02014
Merge branch 'batchable_input_reader' of https://github.com/avnishn/r…
avnishn Jun 27, 2022
407aac3
Temp
avnishn Jun 29, 2022
4439c59
Merge branch 'master' of https://github.com/ray-project/ray into crr_…
avnishn Jun 29, 2022
472cf1c
Working with pendulum
avnishn Jun 29, 2022
f5cd45e
Remove debug statements, comment out CRR tests
avnishn Jun 30, 2022
b062014
Lint
avnishn Jun 30, 2022
85192d6
Fixed issues with runtime cr
avnishn Jul 6, 2022
6d258c4
Lint
avnishn Jul 6, 2022
49abfca
Remove ipdb statement
avnishn Jul 6, 2022
e7baf76
Merge branch 'master' of https://github.com/ray-project/ray into make…
Jul 6, 2022
7ae843b
Change path to paths
Jul 6, 2022
7170438
Add support for file directories
Jul 6, 2022
e087b6a
Use JSONReader for backwards compatibility in test_nested_action_spac…
avnishn Jul 6, 2022
aeafb51
Fix linter
avnishn Jul 6, 2022
fd8da12
Add base option for passing a single path
avnishn Jul 6, 2022
cee6929
Merge branch 'master' of https://github.com/ray-project/ray into make…
avnishn Jul 6, 2022
3bc7fd2
Merge branch 'crr_use_worker_data_2' of https://github.com/avnishn/ra…
avnishn Jul 6, 2022
b0da44a
Remove replay buffer from cql
avnishn Jul 6, 2022
adc1fe0
Remove pendulum json file
avnishn Jul 6, 2022
5ef123d
Enable CRR learning tests
avnishn Jul 6, 2022
9bbc381
Lint
avnishn Jul 6, 2022
f7cb828
Lint
avnishn Jul 6, 2022
b7f758a
Fix functionality for custom input workers
avnishn Jul 7, 2022
616f41c
Fix broken cql test
avnishn Jul 7, 2022
825fc99
Fix broken custom input reader test
avnishn Jul 7, 2022
79c3abb
Reduce cpus used in pendulum cql to fit in cpu reqs of ci
avnishn Jul 7, 2022
3503360
Reduce cpus in cartpole crr to fit in cpu reqs of ci
avnishn Jul 7, 2022
7ef6b6b
Change hparams in pendulum cql to reach convergence quicker
avnishn Jul 7, 2022
b2ba6f2
Address review feedback and reduce num cpus to fall into cpu reqs for ci
avnishn Jul 7, 2022
fb00ad7
Lint
avnishn Jul 7, 2022
ad1d8fb
Address review comments
avnishn Jul 7, 2022
62cad4a
Address comments about target_update name
avnishn Jul 7, 2022
d9a2134
Address Review comments
avnishn Jul 7, 2022
edadcf3
Replace convoluted uses of os with pathlib
avnishn Jul 7, 2022
0bd5948
Fix bug with path in dataset reader for json files
avnishn Jul 8, 2022
2ba824c
Address review comments
avnishn Jul 8, 2022
11d1b6e
Lint
avnishn Jul 8, 2022
22 changes: 11 additions & 11 deletions rllib/BUILD
@@ -254,17 +254,17 @@ py_test(

# CRR
py_test(
name = "learning_tests_pendulum_crr",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include an offline json data file as well.
data = [
"tuned_examples/crr/pendulum-v1-crr.yaml",
"tests/data/pendulum/pendulum_replay_v1.1.0.zip",
],
args = ["--yaml-dir=tuned_examples/crr"]
name = "learning_tests_pendulum_crr",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include an offline json data file as well.
data = [
"tuned_examples/crr/pendulum-v1-crr.yaml",
"tests/data/pendulum/pendulum_replay_v1.1.0.zip",
],
args = ["--yaml-dir=tuned_examples/crr"]
)

py_test(
33 changes: 6 additions & 27 deletions rllib/algorithms/cql/cql.py
@@ -14,7 +14,6 @@
multi_gpu_train_one_step,
train_one_step,
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
@@ -32,8 +31,8 @@
NUM_TARGET_UPDATES,
TARGET_NET_UPDATE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
SAMPLE_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict

tf1, tf, tfv = try_import_tf()
@@ -177,23 +176,11 @@ def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
@override(SAC)
def training_step(self) -> ResultDict:
# Collect SampleBatches from sample workers.
batch = synchronous_parallel_sample(worker_set=self.workers)
batch = batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
# Add batch to replay buffer.
self.local_replay_buffer.add(batch)

# Sample training batch from replay buffer.
train_batch = sample_min_n_steps_from_buffer(
self.local_replay_buffer,
self.config["train_batch_size"],
count_by_agent_steps=self._by_agent_steps,
)

# Old-style replay buffers return None if learning has not started
if not train_batch:
return {}
with self._timers[SAMPLE_TIMER]:
train_batch = synchronous_parallel_sample(worker_set=self.workers)
train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Postprocess batch before we learn on it.
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
@@ -207,14 +194,6 @@ def training_step(self) -> ResultDict:
else:
train_results = multi_gpu_train_one_step(self, train_batch)

# Update replay buffer priorities.
update_priorities_in_replay_buffer(
self.local_replay_buffer,
self.config,
train_batch,
train_results,
)

# Update target network every `target_network_update_freq` training steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
10 changes: 5 additions & 5 deletions rllib/algorithms/cql/tests/test_cql.py
@@ -69,7 +69,8 @@ def test_cql_compilation(self):
evaluation_parallel_to_training=False,
evaluation_num_workers=2,
)
.rollouts(rollout_fragment_length=1)
.rollouts(num_rollout_workers=0)
.reporting(min_time_s_per_iteration=0.0)
)
num_iterations = 4

@@ -85,7 +86,6 @@
f"iter={trainer.iteration} "
f"R={eval_results['episode_reward_mean']}"
)

check_compute_single_action(trainer)

# Get policy and model.
@@ -97,9 +97,9 @@
# Example on how to do evaluation on the trained Trainer
# using the data from CQL's global replay buffer.
# Get a sample (MultiAgentBatch).
multi_agent_batch = trainer.local_replay_buffer.sample(
num_items=config.train_batch_size
)

batch = trainer.workers.local_worker().input_reader.next()
multi_agent_batch = batch.as_multi_agent()
# All experiences have been buffered for `default_policy`
batch = multi_agent_batch.policy_batches["default_policy"]

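The updated test also shows how offline data is now fetched for ad hoc evaluation: straight from the local worker's input reader instead of from the trainer's replay buffer. A minimal sketch of that pattern, assuming an already built trainer object named trainer:

# Sketch of the pattern used in the updated test: read one offline batch from
# the local worker's input reader and unpack the default policy's data.
batch = trainer.workers.local_worker().input_reader.next()
multi_agent_batch = batch.as_multi_agent()
# In the single-agent case all experiences live under "default_policy".
policy_batch = multi_agent_batch.policy_batches["default_policy"]
obs = policy_batch["obs"]  # observation column of the offline batch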
165 changes: 58 additions & 107 deletions rllib/algorithms/crr/crr.py
@@ -1,21 +1,21 @@
import logging
from typing import List, Optional, Type

import numpy as np
import tree

from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
from ray.rllib.execution import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED,
NUM_TARGET_UPDATES,
TARGET_NET_UPDATE_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
SAMPLE_TIMER,
)
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
@@ -38,19 +38,13 @@ def __init__(self, algo_class=None):
self.advantage_type = "mean"
self.n_action_sample = 4
self.twin_q = True
self.target_update_grad_intervals = 100
self.train_batch_size = 128

# target_network_update_freq by default is 100 * train_batch_size
# if target_network_update_freq is not set. See self.setup for code.
self.target_network_update_freq = None
# __sphinx_doc_end__
# fmt: on
self.replay_buffer_config = {
"type": MultiAgentReplayBuffer,
"capacity": 50000,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
"replay_batch_size": 32,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
}
self.actor_hiddens = [256, 256]
self.actor_hidden_activation = "relu"
self.critic_hiddens = [256, 256]
@@ -60,7 +54,10 @@ def __init__(self, algo_class=None):
self.tau = 5e-3

# overriding the trainer config default
self.num_workers = 0 # offline RL does not need rollout workers
# If data ingestion/sample_time is slow, increase this
self.num_workers = 4
self.offline_sampling = True
self.min_iter_time_s = 10.0

def training(
self,
@@ -71,8 +68,7 @@ def training(
advantage_type: Optional[str] = None,
n_action_sample: Optional[int] = None,
twin_q: Optional[bool] = None,
target_update_grad_intervals: Optional[int] = None,
replay_buffer_config: Optional[dict] = None,
target_network_update_freq: Optional[int] = None,
actor_hiddens: Optional[List[int]] = None,
actor_hidden_activation: Optional[str] = None,
critic_hiddens: Optional[List[int]] = None,
@@ -110,10 +106,9 @@ def training(
a^j)]
n_action_sample: the number of actions to sample for v_t estimation.
twin_q: if True, uses pessimistic q estimation.
target_update_grad_intervals: The frequency at which we update the
target_network_update_freq: The frequency at which we update the
target copy of the model in terms of the number of gradient updates
applied to the main model.
replay_buffer_config: The config dictionary for replay buffer.
actor_hiddens: The number of hidden units in the actor's fc network.
actor_hidden_activation: The activation used in the actor's fc network.
critic_hiddens: The number of hidden units in the critic's fc network.
@@ -139,10 +134,8 @@ def training(
self.n_action_sample = n_action_sample
if twin_q is not None:
self.twin_q = twin_q
if target_update_grad_intervals is not None:
self.target_update_grad_intervals = target_update_grad_intervals
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config
if target_network_update_freq is not None:
self.target_network_update_freq = target_network_update_freq
if actor_hiddens is not None:
self.actor_hiddens = actor_hiddens
if actor_hidden_activation is not None:
Expand All @@ -168,44 +161,10 @@ class CRR(Algorithm):

def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# initial setup for handling the offline data in form of a replay buffer
# Add the entire dataset to Replay Buffer (global variable)
reader = self.workers.local_worker().input_reader

# For d4rl, add the D4RLReaders' dataset to the buffer.
if isinstance(self.config["input"], str) and "d4rl" in self.config["input"]:
dataset = reader.dataset
self.local_replay_buffer.add(dataset)
# For a list of files, add each file's entire content to the buffer.
elif isinstance(reader, ShuffledInput):
num_batches = 0
total_timesteps = 0
for batch in reader.child.read_all_files():
num_batches += 1
total_timesteps += len(batch)
# Add NEXT_OBS if not available. This is slightly hacked
# as for the very last time step, we will use next-obs=zeros
# and therefore force-set DONE=True to avoid this missing
# next-obs to cause learning problems.
if SampleBatch.NEXT_OBS not in batch:
obs = batch[SampleBatch.OBS]
batch[SampleBatch.NEXT_OBS] = np.concatenate(
[obs[1:], np.zeros_like(obs[0:1])]
)
batch[SampleBatch.DONES][-1] = True
self.local_replay_buffer.add(batch)
print(
f"Loaded {num_batches} batches ({total_timesteps} ts) into the"
" replay buffer, which has capacity "
f"{self.local_replay_buffer.capacity}."
)
else:
raise ValueError(
"Unknown offline input! config['input'] must either be list of"
" offline files (json) or a D4RL-specific InputReader "
"specifier (e.g. 'd4rl.hopper-medium-v0')."
if self.config.get("target_network_update_freq", None) is None:
self.config["target_network_update_freq"] = (
self.config["train_batch_size"] * 100
)

# added a counter key for keeping track of number of gradient updates
self._counters[NUM_GRADIENT_UPDATES] = 0
# if I don't set this here to zero I won't see zero in the logs (defaultdict)
@@ -227,47 +186,39 @@ def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:

@override(Algorithm)
def training_step(self) -> ResultDict:

total_transitions = len(self.local_replay_buffer)
bsize = self.config["train_batch_size"]
n_batches_per_epoch = total_transitions // bsize

results = []
for batch_iter in range(n_batches_per_epoch):
# Sample training batch from replay buffer.
train_batch = self.local_replay_buffer.sample(bsize)

# Postprocess batch before we learn on it.
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
train_batch = post_fn(train_batch, self.workers, self.config)

# Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer", False):
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)

# update target every few gradient updates
cur_ts = self._counters[NUM_GRADIENT_UPDATES]
last_update = self._counters[LAST_TARGET_UPDATE_TS]

if cur_ts - last_update >= self.config["target_update_grad_intervals"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

self._counters[NUM_GRADIENT_UPDATES] += 1

results.append(train_results)

summary = tree.map_structure_with_path(
lambda path, *v: float(np.mean(v)), *results
)

return summary
with self._timers[SAMPLE_TIMER]:
train_batch = synchronous_parallel_sample(worker_set=self.workers)
train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Postprocess batch before we learn on it.
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
train_batch = post_fn(train_batch, self.workers, self.config)

# Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer", False):
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)

# update target every few gradient updates
# Update target network every `target_network_update_freq` training steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]

if cur_ts - last_update >= self.config["target_network_update_freq"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

self._counters[NUM_GRADIENT_UPDATES] += 1
return train_results
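Two points about the rewritten CRR loop: training_step now mirrors CQL's (a single synchronous_parallel_sample call per step, timed under SAMPLE_TIMER, instead of an epoch of replay-buffer sampling), and the target network is synced on trained timesteps rather than on a gradient-update counter. When target_network_update_freq is left unset, setup() derives it from the train batch size; a quick worked example with the defaults shown in this diff:

# Worked example of the default computed in CRR.setup(), using values from this diff:
train_batch_size = 128  # CRRConfig default above
target_network_update_freq = 100 * train_batch_size  # -> 12800 trained timesteps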