diff --git a/rllib/BUILD b/rllib/BUILD
index 097d0355b3d1..736136eebe76 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -462,7 +462,7 @@ py_test(
     name = "learning_tests_cartpole_marwil",
     main = "tuned_examples/marwil/cartpole_marwil.py",
     tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
-    size = "medium",
+    size = "large",
     srcs = ["tuned_examples/marwil/cartpole_marwil.py"],
     # Include the zipped json data file as well.
     data = [
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 78ce811c1eb0..9316d46f8522 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -431,6 +431,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self.input_read_method = "read_parquet"
         self.input_read_method_kwargs = {}
         self.input_read_schema = {}
+        self.input_read_episodes = False
         self.map_batches_kwargs = {}
         self.iter_batches_kwargs = {}
         self.prelearner_class = None
@@ -2385,6 +2386,7 @@ def offline_data(
         input_read_method: Optional[Union[str, Callable]] = NotProvided,
         input_read_method_kwargs: Optional[Dict] = NotProvided,
         input_read_schema: Optional[Dict[str, str]] = NotProvided,
+        input_read_episodes: Optional[bool] = NotProvided,
         map_batches_kwargs: Optional[Dict] = NotProvided,
         iter_batches_kwargs: Optional[Dict] = NotProvided,
         prelearner_class: Optional[Type] = NotProvided,
@@ -2437,6 +2439,16 @@ def offline_data(
                 schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data
                 set contains already the names in this schema, no `input_read_schema`
                 is needed.
+            input_read_episodes: Whether the offline data is already stored in
+                RLlib's `EpisodeType` format, i.e. `ray.rllib.env.SingleAgentEpisode`
+                (multi-agent is planned but not supported, yet). Reading episodes
+                directly avoids an additional transformation step and is usually
+                faster; it is therefore the advised format when your application
+                remains fully inside of RLlib's schema. The alternative is a columnar
+                format that is agnostic to the RL framework used. Use the latter
+                format if you are unsure where or in which RL framework the data will
+                be used. The default is to read column data, i.e. `False`. See also
+                `output_write_episodes` to define the output data format when recording.
             map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
                 passed into the `ray.data.Dataset.map_batches` method when sampling
                 without checking. If no arguments passed in the default arguments `{
@@ -2528,6 +2540,8 @@ def offline_data(
             self.input_read_method_kwargs = input_read_method_kwargs
         if input_read_schema is not NotProvided:
             self.input_read_schema = input_read_schema
+        if input_read_episodes is not NotProvided:
+            self.input_read_episodes = input_read_episodes
         if map_batches_kwargs is not NotProvided:
             self.map_batches_kwargs = map_batches_kwargs
         if iter_batches_kwargs is not NotProvided:
diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py
index d42ccfa3b5e7..9ba3e937b7e7 100644
--- a/rllib/algorithms/marwil/marwil.py
+++ b/rllib/algorithms/marwil/marwil.py
@@ -192,7 +192,7 @@ def get_default_rl_module_spec(self) -> RLModuleSpecType:
         else:
             raise ValueError(
                 f"The framework {self.framework_str} is not supported. "
-                "Use either 'torch' or 'tf2'."
+                "Use 'torch' instead."
             )
 
     @override(AlgorithmConfig)
@@ -205,7 +205,8 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
             return MARWILTorchLearner
         else:
             raise ValueError(
-                f"The framework {self.framework_str} is not supported. " "Use 'torch'."
+                f"The framework {self.framework_str} is not supported. "
+                "Use 'torch' instead."
             )
 
     @override(AlgorithmConfig)
@@ -324,7 +325,7 @@ def training_step(self) -> ResultDict:
         elif self.config.enable_rl_module_and_learner:
             raise ValueError(
                 "`enable_rl_module_and_learner=True`. Hybrid stack is not "
-                "is not supported for MARWIL. Either use the old stack with "
+                "supported for MARWIL. Either use the old stack with "
                 "`ModelV2` or the new stack with `RLModule`. You can enable "
                 "the new stack by setting both, `enable_rl_module_and_learner` "
                 "and `enable_env_runner_and_connector_v2` to `True`."
diff --git a/rllib/algorithms/marwil/marwil_catalog.py b/rllib/algorithms/marwil/marwil_catalog.py
index acabfa13fb11..16f1a2d7e847 100644
--- a/rllib/algorithms/marwil/marwil_catalog.py
+++ b/rllib/algorithms/marwil/marwil_catalog.py
@@ -16,8 +16,22 @@ class MARWILCatalog(Catalog):
     """The Catalog class used to build models for MARWIL.
 
     MARWILCatalog provides the following models:
-
-
+        - ActorCriticEncoder: The encoder used to encode the observations.
+        - Pi Head: The head used to compute the policy logits.
+        - Value Function Head: The head used to compute the value function.
+
+    The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
+    for the policy and value function. See implementations of MARWILRLModule for
+    more details.
+
+    Any custom ActorCriticEncoder can be built by overriding the
+    build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
+    at MARWILCatalog.actor_critic_encoder_config can be overridden to build a custom
+    ActorCriticEncoder during RLModule runtime.
+
+    Any custom head can be built by overriding the build_pi_head() and build_vf_head()
+    methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
+    build custom heads during RLModule runtime.
     """
 
     def __init__(
diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py
index 72c1e52f0f11..07a291c882f2 100644
--- a/rllib/offline/offline_prelearner.py
+++ b/rllib/offline/offline_prelearner.py
@@ -86,6 +86,7 @@ def __init__(
     ):
 
         self.config = config
+        self.input_read_episodes = self.config.input_read_episodes
         # We need this learner to run the learner connector pipeline.
         # If it is a `Learner` instance, the `Learner` is local.
         if isinstance(learner, Learner):
@@ -130,10 +131,17 @@ def __init__(
 
     @OverrideToImplementCustomLogic
     def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
-        # Map the batch to episodes.
-        episodes = self._map_to_episodes(
-            self._is_multi_agent, batch, schema=SCHEMA | self.config.input_read_schema
-        )
+
+        # If we read in episodes directly, we just convert the batch to a list.
+        if self.input_read_episodes:
+            episodes = batch["item"].tolist()
+        # Otherwise, we map the batch to episodes.
+        else:
+            episodes = self._map_to_episodes(
+                self._is_multi_agent,
+                batch,
+                schema=SCHEMA | self.config.input_read_schema,
+            )["episodes"]
         # TODO (simon): Make synching work. Right now this becomes blocking or never
         # receives weights. Learners appear to be non accessable via other actors.
         # Increase the counter for updating the module.
@@ -165,7 +173,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]
         batch = self._learner_connector(
             rl_module=self._module,
            data={},
-            episodes=episodes["episodes"],
+            episodes=episodes,
            shared_data={},
        )
         # Convert to `MultiAgentBatch`.
@@ -176,7 +184,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]
             },
             # TODO (simon): This can be run once for the batch and the
             # metrics, but we run it twice: here and later in the learner.
-            env_steps=sum(e.env_steps() for e in episodes["episodes"]),
+            env_steps=sum(e.env_steps() for e in episodes),
         )
         # Remove all data from modules that should not be trained. We do
         # not want to pass around more data than necessaty.
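
Below is a minimal usage sketch for the new `input_read_episodes` flag, assuming MARWIL as the consuming algorithm, `CartPole-v1` as the environment, and a placeholder path that already contains data recorded as `SingleAgentEpisode`s; the path, the environment, and the explicit `api_stack()` call are illustrative assumptions, not part of this change.

from ray.rllib.algorithms.marwil import MARWILConfig

config = (
    MARWILConfig()
    # The new offline data pipeline runs on the new API stack.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .offline_data(
        # Hypothetical path holding previously recorded episode data.
        input_="/tmp/cartpole-episodes",
        # Read rows as `SingleAgentEpisode` objects directly and skip the
        # column-to-episode mapping step in the `OfflinePrelearner`.
        input_read_episodes=True,
    )
)
algo = config.build()
print(algo.train())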
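
A companion sketch for producing such data in the first place, assuming PPO as the recording algorithm and the `output_write_episodes` flag that the docstring above refers to; the output directory is again a placeholder.

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .offline_data(
        # Hypothetical directory to write the collected episodes to.
        output="/tmp/cartpole-episodes",
        # Write `SingleAgentEpisode`s instead of column data, so the files can
        # later be read back with `input_read_episodes=True`.
        output_write_episodes=True,
    )
)
algo = config.build()
algo.train()  # Each iteration also writes the sampled episodes to `output`.
algo.stop()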