[RLlib; Offline RL] Add support to directly read from episodes. (ray-…
simonsays1980 authored Aug 12, 2024
1 parent 768108f commit 936427f
Showing 5 changed files with 49 additions and 12 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -462,7 +462,7 @@ py_test(
name = "learning_tests_cartpole_marwil",
main = "tuned_examples/marwil/cartpole_marwil.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "medium",
size = "large",
srcs = ["tuned_examples/marwil/cartpole_marwil.py"],
# Include the zipped json data file as well.
data = [
14 changes: 14 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -431,6 +431,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.input_read_episodes = False
self.map_batches_kwargs = {}
self.iter_batches_kwargs = {}
self.prelearner_class = None
@@ -2385,6 +2386,7 @@ def offline_data(
input_read_method: Optional[Union[str, Callable]] = NotProvided,
input_read_method_kwargs: Optional[Dict] = NotProvided,
input_read_schema: Optional[Dict[str, str]] = NotProvided,
input_read_episodes: Optional[bool] = NotProvided,
map_batches_kwargs: Optional[Dict] = NotProvided,
iter_batches_kwargs: Optional[Dict] = NotProvided,
prelearner_class: Optional[Type] = NotProvided,
@@ -2437,6 +2439,16 @@ def offline_data(
schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data set
already contains the names in this schema, no `input_read_schema` is
needed.
input_read_episodes: Whether the offline data is already stored in RLlib's
`EpisodeType` format, i.e. `ray.rllib.env.SingleAgentEpisode` (multi-agent
episodes are planned but not yet supported). Reading episodes directly
avoids an additional transformation step and is usually faster; it is
therefore the advised format when your application stays fully inside
RLlib's schema. The alternative is a columnar format that is agnostic
to the RL framework used. Use the columnar format if you are unsure when
or in which RL framework the data will be used. The default is to read
column data, i.e. `False`. See also `output_write_episodes` to define the
output data format when recording.
map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments are passed in, the default arguments `{
@@ -2528,6 +2540,8 @@ def offline_data(
self.input_read_method_kwargs = input_read_method_kwargs
if input_read_schema is not NotProvided:
self.input_read_schema = input_read_schema
if input_read_episodes is not NotProvided:
self.input_read_episodes = input_read_episodes
if map_batches_kwargs is not NotProvided:
self.map_batches_kwargs = map_batches_kwargs
if iter_batches_kwargs is not NotProvided:
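For context, a minimal usage sketch (not part of the diff) of how the new `input_read_episodes` flag would be set through `AlgorithmConfig.offline_data()`. The environment name and input path are illustrative placeholders:

```python
from ray.rllib.algorithms.marwil import MARWILConfig

# Illustrative sketch: enable direct episode reading for offline MARWIL
# training. "CartPole-v1" and the input path are placeholder assumptions.
config = (
    MARWILConfig()
    .environment("CartPole-v1")
    .offline_data(
        # Path to data previously recorded as `SingleAgentEpisode` objects,
        # e.g. via `output_write_episodes=True` when recording.
        input_="/tmp/cartpole/episodes",
        # New flag from this commit: read episodes directly instead of
        # mapping column data to episodes first.
        input_read_episodes=True,
    )
)
```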
7 changes: 4 additions & 3 deletions rllib/algorithms/marwil/marwil.py
@@ -192,7 +192,7 @@ def get_default_rl_module_spec(self) -> RLModuleSpecType:
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
"Use 'torch' instead."
)

@override(AlgorithmConfig)
@@ -205,7 +205,8 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
return MARWILTorchLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. " "Use 'torch'."
f"The framework {self.framework_str} is not supported. "
"Use 'torch' instead."
)

@override(AlgorithmConfig)
@@ -324,7 +325,7 @@ def training_step(self) -> ResultDict:
elif self.config.enable_rl_module_and_learner:
raise ValueError(
"`enable_rl_module_and_learner=True`. Hybrid stack is not "
"is not supported for MARWIL. Either use the old stack with "
"supported for MARWIL. Either use the old stack with "
"`ModelV2` or the new stack with `RLModule`. You can enable "
"the new stack by setting both, `enable_rl_module_and_learner` "
"and `enable_env_runner_and_connector_v2` to `True`."
18 changes: 16 additions & 2 deletions rllib/algorithms/marwil/marwil_catalog.py
@@ -16,8 +16,22 @@ class MARWILCatalog(Catalog):
"""The Catalog class used to build models for MARWIL.
MARWILCatalog provides the following models:
- ActorCriticEncoder: The encoder used to encode the observations.
- Pi Head: The head used to compute the policy logits.
- Value Function Head: The head used to compute the value function.
The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
for the policy and value function. See implementations of MARWILRLModule for
more details.
ny custom ActorCriticEncoder can be built by overriding the
build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
at MARWILCatalog.actor_critic_encoder_config can be overridden to build a custom
ActorCriticEncoder during RLModule runtime.
Any custom head can be built by overriding the build_pi_head() and build_vf_head()
methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
build custom heads during RLModule runtime.
"""

def __init__(
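A minimal sketch of the override pattern the docstring describes; `MyMARWILCatalog` is a hypothetical subclass and not part of this commit:

```python
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog


class MyMARWILCatalog(MARWILCatalog):
    """Hypothetical catalog sketch that customizes the policy head."""

    def build_pi_head(self, framework: str):
        # Delegate to the default pi-head builder here; a real override would
        # return a custom model (e.g. one with a different hidden-layer stack).
        return super().build_pi_head(framework=framework)
```

A catalog like this is typically plugged in through the RLModule spec's `catalog_class` argument.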
20 changes: 14 additions & 6 deletions rllib/offline/offline_prelearner.py
@@ -86,6 +86,7 @@ def __init__(
):

self.config = config
self.input_read_episodes = self.config.input_read_episodes
# We need this learner to run the learner connector pipeline.
# If it is a `Learner` instance, the `Learner` is local.
if isinstance(learner, Learner):
@@ -130,10 +131,17 @@ def __init__(

@OverrideToImplementCustomLogic
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
# Map the batch to episodes.
episodes = self._map_to_episodes(
self._is_multi_agent, batch, schema=SCHEMA | self.config.input_read_schema
)

# If we read in episodes directly, we just convert the batch to a list.
if self.input_read_episodes:
episodes = batch["item"].tolist()
# Otherwise we map the batch to episodes.
else:
episodes = self._map_to_episodes(
self._is_multi_agent,
batch,
schema=SCHEMA | self.config.input_read_schema,
)["episodes"]
# TODO (simon): Make syncing work. Right now this becomes blocking or never
# receives weights. Learners appear to be not accessible via other actors.
# Increase the counter for updating the module.
@@ -165,7 +173,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
batch = self._learner_connector(
rl_module=self._module,
data={},
episodes=episodes["episodes"],
episodes=episodes,
shared_data={},
)
# Convert to `MultiAgentBatch`.
@@ -176,7 +184,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
},
# TODO (simon): This can be run once for the batch and the
# metrics, but we run it twice: here and later in the learner.
env_steps=sum(e.env_steps() for e in episodes["episodes"]),
env_steps=sum(e.env_steps() for e in episodes),
)
# Remove all data from modules that should not be trained. We do
# not want to pass around more data than necessary.
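To make the two branches in `__call__` concrete, here is an illustrative sketch (not from the diff) of what the incoming Ray Data batch looks like in each mode; the empty episodes are placeholders:

```python
import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# With `input_read_episodes=True`, every row of the Ray Data batch already
# holds a `SingleAgentEpisode` object under the "item" key, so a plain list
# conversion is all that is needed before the learner connector runs.
episode_batch = {
    "item": np.array([SingleAgentEpisode(), SingleAgentEpisode()], dtype=object),
}
episodes = episode_batch["item"].tolist()  # -> list of SingleAgentEpisode

# With column data (the default), the batch instead holds per-column arrays
# (observations, actions, rewards, ...) and `_map_to_episodes()` first has to
# assemble `SingleAgentEpisode` objects from those columns.
total_env_steps = sum(e.env_steps() for e in episodes)  # 0 for this empty sketch
```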
