From bb8a286cbc9de8bcc5d225f43e7d2faf047f072e Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 27 Apr 2021 10:44:54 +0200 Subject: [PATCH] [RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). (#14684) --- rllib/BUILD | 2 +- rllib/agents/ddpg/tests/test_ddpg.py | 4 +- rllib/agents/dqn/tests/test_simple_q.py | 4 +- rllib/agents/marwil/marwil_tf_policy.py | 3 +- rllib/agents/pg/tests/test_pg.py | 4 +- rllib/agents/ppo/ppo_tf_policy.py | 38 ++++-- rllib/agents/ppo/tests/test_ppo.py | 8 +- rllib/agents/sac/tests/test_sac.py | 4 +- .../collectors/simple_list_collector.py | 5 +- rllib/evaluation/postprocessing.py | 3 +- rllib/examples/models/modelv3.py | 57 ++++++++ rllib/models/catalog.py | 113 ++++++++++++---- rllib/models/modelv2.py | 67 +-------- rllib/models/tests/test_models.py | 28 ++++ rllib/models/tf/fcnet.py | 128 ++++++++++++++++++ rllib/policy/dynamic_tf_policy.py | 30 ++-- rllib/policy/eager_tf_policy.py | 20 ++- rllib/policy/policy.py | 17 ++- rllib/policy/rnn_sequencing.py | 3 +- rllib/policy/sample_batch.py | 81 +++++++++-- rllib/policy/tests/test_sample_batch.py | 4 +- rllib/policy/tf_policy.py | 15 +- rllib/policy/tf_policy_template.py | 11 +- rllib/utils/exploration/random.py | 6 +- 24 files changed, 488 insertions(+), 167 deletions(-) create mode 100644 rllib/examples/models/modelv3.py diff --git a/rllib/BUILD b/rllib/BUILD index d2ab83c0943ef..e20e88e46466c 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1178,7 +1178,7 @@ py_test( py_test( name = "test_models", tags = ["models"], - size = "small", + size = "medium", srcs = ["models/tests/test_models.py"] ) diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py index c5c83cc6050f8..dd0f29fc915c4 100644 --- a/rllib/agents/ddpg/tests/test_ddpg.py +++ b/rllib/agents/ddpg/tests/test_ddpg.py @@ -414,7 +414,7 @@ def test_ddpg_loss_function(self): trainer.stop() def _get_batch_helper(self, obs_size, actions, batch_size): - return { + return SampleBatch({ SampleBatch.CUR_OBS: np.random.random(size=obs_size), SampleBatch.ACTIONS: actions, SampleBatch.REWARDS: np.random.random(size=(batch_size, )), @@ -422,7 +422,7 @@ def _get_batch_helper(self, obs_size, actions, batch_size): [True, False], size=(batch_size, )), SampleBatch.NEXT_OBS: np.random.random(size=obs_size), "weights": np.ones(shape=(batch_size, )), - } + }) def _ddpg_loss_helper(self, train_batch, weights, ks, fw, gamma, huber_threshold, l2_reg, sess): diff --git a/rllib/agents/dqn/tests/test_simple_q.py b/rllib/agents/dqn/tests/test_simple_q.py index 09524fa21618f..b5caaf5e09664 100644 --- a/rllib/agents/dqn/tests/test_simple_q.py +++ b/rllib/agents/dqn/tests/test_simple_q.py @@ -81,7 +81,7 @@ def test_simple_q_loss_function(self): trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0") policy = trainer.get_policy() # Batch of size=2. - input_ = { + input_ = SampleBatch({ SampleBatch.CUR_OBS: np.random.random(size=(2, 4)), SampleBatch.ACTIONS: np.array([0, 1]), SampleBatch.REWARDS: np.array([0.4, -1.23]), @@ -94,7 +94,7 @@ def test_simple_q_loss_function(self): [-0.1, -0.2]]), SampleBatch.ACTION_PROB: np.array([0.1, 0.2]), "q_values": np.array([[0.1, 0.2], [0.2, 0.1]]), - } + }) # Get model vars for computing expected model outs (q-vals). 
# 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias vars = policy.get_weights() diff --git a/rllib/agents/marwil/marwil_tf_policy.py b/rllib/agents/marwil/marwil_tf_policy.py index c4abd7fb07b29..e741983b230ad 100644 --- a/rllib/agents/marwil/marwil_tf_policy.py +++ b/rllib/agents/marwil/marwil_tf_policy.py @@ -67,7 +67,8 @@ def postprocess_advantages(policy, # input_dict. # Create an input dict according to the Model's requirements. index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1 - input_dict = policy.model.get_input_dict(sample_batch, index=index) + input_dict = sample_batch.get_single_step_input_dict( + policy.model.view_requirements, index=index) last_r = policy._value(**input_dict) # Adds the "advantages" (which in the case of MARWIL are simply the diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py index 538818ebd5bf6..2eb3ef32bf13d 100644 --- a/rllib/agents/pg/tests/test_pg.py +++ b/rllib/agents/pg/tests/test_pg.py @@ -70,7 +70,7 @@ def test_pg_loss_functions(self): config["model"]["fcnet_activation"] = "linear" # Fake CartPole episode of n time steps. - train_batch = { + train_batch = SampleBatch({ SampleBatch.OBS: np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]), @@ -79,7 +79,7 @@ def test_pg_loss_functions(self): SampleBatch.DONES: np.array([False, False, True]), SampleBatch.EPS_ID: np.array([1234, 1234, 1234]), SampleBatch.AGENT_INDEX: np.array([0, 0, 0]), - } + }) for fw, sess in framework_iterator(config, session=True): dist_cls = (Categorical if fw != "torch" else TorchCategorical) diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 4d1084a6851a0..473950f3a57bb 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -29,13 +29,15 @@ def ppo_surrogate_loss( - policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], + policy: Policy, model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. - model (ModelV2): The Model to calculate the loss for. + model (Union[ModelV2, tf.keras.Model]): The Model to calculate + the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. @@ -43,7 +45,13 @@ def ppo_surrogate_loss( Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ - logits, state = model.from_batch(train_batch) + if isinstance(model, tf.keras.Model): + logits, state, extra_outs = model(train_batch) + value_fn_out = extra_outs[SampleBatch.VF_PREDS] + else: + logits, state = model.from_batch(train_batch) + value_fn_out = model.value_function() + curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. 
@@ -86,7 +94,6 @@ def reduce_mean_valid(t): if policy.config["use_gae"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] - value_fn_out = model.value_function() vf_loss1 = tf.math.square(value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]) vf_clipped = prev_value_fn_out + tf.clip_by_value( @@ -112,6 +119,7 @@ def reduce_mean_valid(t): policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._mean_kl = mean_kl + policy._value_fn_out = value_fn_out return total_loss @@ -134,14 +142,14 @@ def kl_and_loss_stats(policy: Policy, "policy_loss": policy._mean_policy_loss, "vf_loss": policy._mean_vf_loss, "vf_explained_var": explained_variance( - train_batch[Postprocessing.VALUE_TARGETS], - policy.model.value_function()), + train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out), "kl": policy._mean_kl, "entropy": policy._mean_entropy, "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64), } +# TODO: (sven) Deprecate once we only allow native keras models. def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]: """Defines extra fetches per action computation. @@ -152,6 +160,10 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]: Dict[str, TensorType]: Dict with extra tf fetches to perform per action computation. """ + # Keras models return values for each call in third return argument + # (dict). + if isinstance(policy.model, tf.keras.Model): + return {} # Return value function outputs. VF estimates will hence be added to the # SampleBatches produced by the sampler(s) to generate the train batches # going into the loss function. @@ -177,7 +189,9 @@ def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer, tuples. """ # Compute the gradients. - variables = policy.model.trainable_variables() + variables = policy.model.trainable_variables + if isinstance(policy.model, ModelV2): + variables = variables() grads_and_vars = optimizer.compute_gradients(loss, variables) # Clip by global norm, if necessary. @@ -267,9 +281,13 @@ def __init__(self, obs_space, action_space, config): @make_tf_callable(self.get_session()) def value(**input_dict): input_dict = SampleBatch(input_dict) - model_out, _ = self.model(input_dict) - # [0] = remove the batch dim. - return self.model.value_function()[0] + if isinstance(self.model, tf.keras.Model): + _, _, extra_outs = self.model(input_dict) + return extra_outs[SampleBatch.VF_PREDS][0] + else: + model_out, _ = self.model(input_dict) + # [0] = remove the batch dim. + return self.model.value_function()[0] # When not doing GAE, we do not require the value function's output. else: diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index c6b7412897ca6..ce100c0ef6890 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -20,7 +20,7 @@ check_compute_single_action # Fake CartPole episode of n time steps. -FAKE_BATCH = { +FAKE_BATCH = SampleBatch({ SampleBatch.OBS: np.array( [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=np.float32), @@ -35,7 +35,7 @@ SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32), SampleBatch.EPS_ID: np.array([0, 0, 0]), SampleBatch.AGENT_INDEX: np.array([0, 0, 0]), -} +}) class MyCallbacks(DefaultCallbacks): @@ -82,6 +82,9 @@ def test_ppo_compilation_and_lr_schedule(self): # Settings in case we use an LSTM. config["model"]["lstm_cell_size"] = 10 config["model"]["max_seq_len"] = 20 + # Use default-native keras model whenever possible. 
+ config["model"]["_use_default_native_models"] = True + config["train_batch_size"] = 128 # Test with compression. config["compress_observations"] = True @@ -95,6 +98,7 @@ def test_ppo_compilation_and_lr_schedule(self): config["model"]["use_lstm"] = lstm config["model"]["lstm_use_prev_action"] = lstm config["model"]["lstm_use_prev_reward"] = lstm + trainer = ppo.PPOTrainer(config=config, env=env) for i in range(num_iterations): trainer.train() diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index 894f686536a9f..3f41bdfc16bce 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -448,7 +448,7 @@ def test_sac_loss_function(self): check(tf_var, torch_var, rtol=0.1) def _get_batch_helper(self, obs_size, actions, batch_size): - return { + return SampleBatch({ SampleBatch.CUR_OBS: np.random.random(size=obs_size), SampleBatch.ACTIONS: actions, SampleBatch.REWARDS: np.random.random(size=(batch_size, )), @@ -456,7 +456,7 @@ def _get_batch_helper(self, obs_size, actions, batch_size): [True, False], size=(batch_size, )), SampleBatch.NEXT_OBS: np.random.random(size=obs_size), "weights": np.random.random(size=(batch_size, )), - } + }) def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma, sess): diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 402b7091cddc0..33ed14f639858 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -241,7 +241,7 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: # Due to possible batch-repeats > 1, columns in the resulting batch # may not all have the same batch size. - batch = SampleBatch(batch_data, _dont_check_lens=True) + batch = SampleBatch(batch_data) # Add EPS_ID and UNROLL_ID to batch. batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count) @@ -366,8 +366,7 @@ def build(self): this policy. """ # Create batch from our buffers. - batch = SampleBatch( - self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True) + batch = SampleBatch(self.buffers, _seq_lens=self.seq_lens) # Clear buffers for future samples. self.buffers.clear() # Reset agent steps to 0 and seq-lens to empty list. diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index 703cf58a5751c..2be7f39ebd69d 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -119,7 +119,8 @@ def compute_gae_for_sample_batch( # requirements. It's a single-timestep (last one in trajectory) # input_dict. # Create an input dict according to the Model's requirements. 
- input_dict = policy.model.get_input_dict(sample_batch, index="last") + input_dict = sample_batch.get_single_step_input_dict( + policy.model.view_requirements, index="last") last_r = policy._value(**input_dict, seq_lens=input_dict.seq_lens) # Adds the policy logits, VF preds, and advantages to the batch, diff --git a/rllib/examples/models/modelv3.py b/rllib/examples/models/modelv3.py new file mode 100644 index 0000000000000..e0f5e01a8bc43 --- /dev/null +++ b/rllib/examples/models/modelv3.py @@ -0,0 +1,57 @@ +import numpy as np + +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.framework import try_import_tf, try_import_torch + +tf1, tf, tfv = try_import_tf() +torch, nn = try_import_torch() + + +class RNNModel(tf.keras.models.Model if tf else object): + """Example of using the Keras functional API to define an RNN model.""" + + def __init__(self, + input_space, + action_space, + num_outputs, + *, + name="", + hiddens_size=256, + cell_size=64): + super().__init__(name=name) + + self.cell_size = cell_size + + # Preprocess observation with a hidden layer and send to LSTM cell + self.dense = tf.keras.layers.Dense( + hiddens_size, activation=tf.nn.relu, name="dense1") + self.lstm = tf.keras.layers.LSTM( + cell_size, return_sequences=True, return_state=True, name="lstm") + + # Postprocess LSTM output with another hidden layer and compute + # values. + self.logits = tf.keras.layers.Dense( + num_outputs, activation=tf.keras.activations.linear, name="logits") + self.values = tf.keras.layers.Dense(1, activation=None, name="values") + + def call(self, sample_batch): + dense_out = self.dense(sample_batch["obs"]) + B = tf.shape(sample_batch.seq_lens)[0] + lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]]) + lstm_out, h, c = self.lstm( + inputs=lstm_in, + mask=tf.sequence_mask(sample_batch.seq_lens), + initial_state=[ + sample_batch["state_in_0"], sample_batch["state_in_1"] + ], + ) + lstm_out = tf.reshape(lstm_out, [-1, lstm_out.shape.as_list()[2]]) + logits = self.logits(lstm_out) + values = tf.reshape(self.values(lstm_out), [-1]) + return logits, [h, c], {SampleBatch.VF_PREDS: values} + + def get_initial_state(self): + return [ + np.zeros(self.cell_size, np.float32), + np.zeros(self.cell_size, np.float32), + ] diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 259949aba3ac8..fbbe875e6f369 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -19,7 +19,8 @@ TorchDeterministic, TorchDiagGaussian, \ TorchMultiActionDistribution, TorchMultiCategorical from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI -from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \ + deprecation_warning from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.spaces.simplex import Simplex @@ -34,6 +35,14 @@ # yapf: disable # __sphinx_doc_begin__ MODEL_DEFAULTS: ModelConfigDict = { + # Experimental flag. + # If True, try to use a native (tf.keras.Model or torch.Module) default + # model instead of our built-in ModelV2 defaults. + # If False (default), use "classic" ModelV2 default models. + # Note that this currently only works for framework != torch AND fully + # connected default networks. 
+ "_use_default_native_models": False, + # === Built-in options === # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py # These are used if no custom model is specified and the input space is 1D. @@ -392,10 +401,13 @@ def get_model_v2(obs_space: gym.Space, model_cls = _global_registry.get(RLLIB_MODEL, model_config["custom_model"]) + # Only allow ModelV2 or native keras Models. if not issubclass(model_cls, ModelV2): - raise ValueError( - "`model_cls` must be a ModelV2 sub-class, but is" - " {}!".format(model_cls)) + if framework not in ["tf", "tf2", "tfe"] or \ + not issubclass(model_cls, tf.keras.Model): + raise ValueError( + "`model_cls` must be a ModelV2 sub-class, but is" + " {}!".format(model_cls)) logger.info("Wrapping {} as {}".format(model_cls, model_interface)) model_cls = ModelCatalog._wrap_if_needed(model_cls, @@ -427,30 +439,51 @@ def track_var_creation(next_creator, **kw): return v with tf.variable_creator_scope(track_var_creation): - # Try calling with kwargs first (custom ModelV2 should - # accept these as kwargs, not get them from - # config["custom_model_config"] anymore). - try: - instance = model_cls(obs_space, action_space, - num_outputs, model_config, name, - **customized_model_kwargs) - except TypeError as e: - # Keyword error: Try old way w/o kwargs. - if "__init__() got an unexpected " in e.args[0]: - instance = model_cls(obs_space, action_space, - num_outputs, model_config, - name, **model_kwargs) - logger.warning( - "Custom ModelV2 should accept all custom " - "options as **kwargs, instead of expecting" - " them in config['custom_model_config']!") - # Other error -> re-raise. - else: - raise e + if issubclass(model_cls, tf.keras.Model): + instance = model_cls( + input_space=obs_space, + action_space=action_space, + num_outputs=num_outputs, + name=name, + **customized_model_kwargs, + ) + else: + # Try calling with kwargs first (custom ModelV2 should + # accept these as kwargs, not get them from + # config["custom_model_config"] anymore). + try: + instance = model_cls( + obs_space, + action_space, + num_outputs, + model_config, + name, + **customized_model_kwargs, + ) + except TypeError as e: + # Keyword error: Try old way w/o kwargs. + if "__init__() got an unexpected " in e.args[0]: + instance = model_cls( + obs_space, + action_space, + num_outputs, + model_config, + name, + **model_kwargs, + ) + logger.warning( + "Custom ModelV2 should accept all custom " + "options as **kwargs, instead of expecting" + " them in config['custom_model_config']!") + # Other error -> re-raise. + else: + raise e # User still registered TFModelV2's variables: Check, whether # ok. - registered = set(instance.var_list) + registered = [] + if not isinstance(instance, tf.keras.Model): + registered = set(instance.var_list) if len(registered) > 0: not_registered = set() for var in created: @@ -544,6 +577,15 @@ def track_var_creation(next_creator, **kw): # Wrap in the requested interface. wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface) + + if issubclass(wrapper, tf.keras.Model): + return wrapper( + input_space=obs_space, + action_space=action_space, + num_outputs=num_outputs, + name=name, + **dict(model_kwargs, **model_config), + ) return wrapper(obs_space, action_space, num_outputs, model_config, name, **model_kwargs) @@ -673,6 +715,8 @@ def register_custom_model(model_name: str, model_class: type) -> None: model_name (str): Name to register the model under. model_class (type): Python class of the model. 
""" + if issubclass(model_class, tf.keras.Model): + deprecation_warning(old="register_custom_model", error=False) _global_registry.register(RLLIB_MODEL, model_name, model_class) @staticmethod @@ -693,11 +737,11 @@ def register_custom_action_dist(action_dist_name: str, @staticmethod def _wrap_if_needed(model_cls: type, model_interface: type) -> type: - assert issubclass(model_cls, ModelV2), model_cls - if not model_interface or issubclass(model_cls, model_interface): return model_cls + assert issubclass(model_cls, ModelV2), model_cls + class wrapper(model_interface, model_cls): pass @@ -714,10 +758,12 @@ def _get_v2_model_class(input_space: gym.Space, VisionNet = None ComplexNet = None + Keras_FCNet = None if framework in ["tf2", "tf", "tfe"]: from ray.rllib.models.tf.fcnet import \ - FullyConnectedNetwork as FCNet + FullyConnectedNetwork as FCNet, \ + Keras_FullyConnectedNetwork as Keras_FCNet from ray.rllib.models.tf.visionnet import \ VisionNetwork as VisionNet from ray.rllib.models.tf.complex_input_net import \ @@ -751,12 +797,19 @@ def _get_v2_model_class(input_space: gym.Space, for s in space_to_check.spaces)): return ComplexNet - # Single, flattenable/one-hot-abe space -> Simple FCNet. + # Single, flattenable/one-hot-able space -> Simple FCNet. if isinstance(input_space, (Discrete, MultiDiscrete)) or \ len(input_space.shape) == 1 or ( len(input_space.shape) == 2 and ( num_framestacks == "auto" or num_framestacks <= 1)): - return FCNet + # Keras native requested AND no auto-rnn-wrapping AND . + if model_config.get("_use_default_native_models") and \ + Keras_FCNet and not model_config.get("use_lstm") and \ + not model_config.get("use_attention"): + return Keras_FCNet + # Classic ModelV2 FCNet. + else: + return FCNet elif framework == "jax": raise NotImplementedError("No non-FC default net for JAX yet!") diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index e869220efcb81..557d0a69defde 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -170,7 +170,7 @@ def metrics(self) -> Dict[str, TensorType]: key2: metric2 Returns: - Dict of string keys to scalar tensors. + Dict[str, TensorType]: The custom metrics for this model. """ return {} @@ -333,71 +333,6 @@ def is_time_major(self) -> bool: """ return self.time_major is True - # TODO: (sven) Experimental method. - def get_input_dict(self, - sample_batch: SampleBatch, - index: Union[int, str] = "last") -> SampleBatch: - """Creates single ts input-dict at given index from a SampleBatch. - - Args: - sample_batch (SampleBatch): A single-trajectory SampleBatch object - to generate the compute_actions input dict from. - index (Union[int, str]): An integer index value indicating the - position in the trajectory for which to generate the - compute_actions input dict. Set to "last" to generate the dict - at the very end of the trajectory (e.g. for value estimation). - Note that "last" is different from -1, as "last" will use the - final NEXT_OBS as observation input. - - Returns: - SampleBatch: The (single-timestep) input dict for ModelV2 calls. - """ - last_mappings = { - SampleBatch.OBS: SampleBatch.NEXT_OBS, - SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, - SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, - } - - input_dict = {} - for view_col, view_req in self.view_requirements.items(): - # Create batches of size 1 (single-agent input-dict). - data_col = view_req.data_col or view_col - if index == "last": - data_col = last_mappings.get(data_col, data_col) - # Range needed. 
- if view_req.shift_from is not None: - data = sample_batch[view_col][-1] - traj_len = len(sample_batch[data_col]) - missing_at_end = traj_len % view_req.batch_repeat_value - obs_shift = -1 if data_col in [ - SampleBatch.OBS, SampleBatch.NEXT_OBS - ] else 0 - from_ = view_req.shift_from + obs_shift - to_ = view_req.shift_to + obs_shift + 1 - if to_ == 0: - to_ = None - input_dict[view_col] = np.array([ - np.concatenate([ - data, sample_batch[data_col][-missing_at_end:] - ])[from_:to_] - ]) - # Single index. - else: - data = sample_batch[data_col][-1] - input_dict[view_col] = np.array([data]) - else: - # Index range. - if isinstance(index, tuple): - data = sample_batch[data_col][index[0]:index[1] + 1 - if index[1] != -1 else None] - input_dict[view_col] = np.array([data]) - # Single index. - else: - input_dict[view_col] = sample_batch[data_col][ - index:index + 1 if index != -1 else None] - - return SampleBatch(input_dict, _seq_lens=np.array([1], dtype=np.int32)) - @DeveloperAPI def flatten(obs: TensorType, framework: str) -> TensorType: diff --git a/rllib/models/tests/test_models.py b/rllib/models/tests/test_models.py index 424dea16cb018..1069929501d4f 100644 --- a/rllib/models/tests/test_models.py +++ b/rllib/models/tests/test_models.py @@ -2,6 +2,9 @@ import numpy as np import unittest +import ray +import ray.rllib.agents.ppo as ppo +from ray.rllib.examples.models.modelv3 import RNNModel from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.utils.framework import try_import_tf @@ -33,6 +36,14 @@ def forward(self, input_dict, state, seq_lens): class TestModels(unittest.TestCase): """Tests ModelV2 classes and their modularization capabilities.""" + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + def test_tf_modelv2(self): obs_space = Box(-1.0, 1.0, (3, )) action_space = Box(-1.0, 1.0, (2, )) @@ -52,6 +63,23 @@ def test_tf_modelv2(self): self.assertTrue("fc_net.base_model.value_out.kernel:0" in vars) self.assertTrue("fc_net.base_model.value_out.bias:0" in vars) + def test_modelv3(self): + config = { + "env": "CartPole-v0", + "model": { + "custom_model": RNNModel, + "custom_model_config": { + "hiddens_size": 64, + "cell_size": 128, + }, + }, + "num_workers": 0, + } + trainer = ppo.PPOTrainer(config=config) + for _ in range(2): + results = trainer.train() + print(results) + if __name__ == "__main__": import pytest diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py index 9b0e8c5653742..b511cfcfb9305 100644 --- a/rllib/models/tf/fcnet.py +++ b/rllib/models/tf/fcnet.py @@ -1,9 +1,11 @@ import numpy as np import gym +from typing import Optional, Sequence, Tuple from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.utils import get_activation_fn +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict @@ -127,3 +129,129 @@ def forward(self, input_dict: Dict[str, TensorType], def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) + + +class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object): + """Generic fully connected network implemented in tf Keras.""" + + def __init__( + self, + input_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + *, + name: str = "", + 
fcnet_hiddens: Optional[Sequence[int]] = (), + fcnet_activation: Optional[str] = None, + post_fcnet_hiddens: Optional[Sequence[int]] = (), + post_fcnet_activation: Optional[str] = None, + no_final_linear: bool = False, + vf_share_layers: bool = False, + free_log_std: bool = False, + **kwargs, + ): + super().__init__(name=name) + + hiddens = list(fcnet_hiddens or ()) + \ + list(post_fcnet_hiddens or ()) + activation = fcnet_activation + if not fcnet_hiddens: + activation = post_fcnet_activation + activation = get_activation_fn(activation) + + # Generate free-floating bias variables for the second half of + # the outputs. + if free_log_std: + assert num_outputs % 2 == 0, ( + "num_outputs must be divisible by two", num_outputs) + num_outputs = num_outputs // 2 + self.log_std_var = tf.Variable( + [0.0] * num_outputs, dtype=tf.float32, name="log_std") + + # We are using obs_flat, so take the flattened shape as input. + inputs = tf.keras.layers.Input( + shape=(int(np.product(input_space.shape)), ), name="observations") + # Last hidden layer output (before logits outputs). + last_layer = inputs + # The action distribution outputs. + logits_out = None + i = 1 + + # Create layers 0 to second-last. + for size in hiddens[:-1]: + last_layer = tf.keras.layers.Dense( + size, + name="fc_{}".format(i), + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + i += 1 + + # The last layer is adjusted to be of size num_outputs, but it's a + # layer with activation. + if no_final_linear and num_outputs: + logits_out = tf.keras.layers.Dense( + num_outputs, + name="fc_out", + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + # Finish the layers with the provided sizes (`hiddens`), plus - + # iff num_outputs > 0 - a last linear layer of size num_outputs. + else: + if len(hiddens) > 0: + last_layer = tf.keras.layers.Dense( + hiddens[-1], + name="fc_{}".format(i), + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + if num_outputs: + logits_out = tf.keras.layers.Dense( + num_outputs, + name="fc_out", + activation=None, + kernel_initializer=normc_initializer(0.01))(last_layer) + # Adjust num_outputs to be the number of nodes in the last layer. + else: + self.num_outputs = ( + [int(np.product(input_space.shape))] + hiddens[-1:])[-1] + + # Concat the log std vars to the end of the state-dependent means. + if free_log_std and logits_out is not None: + + def tiled_log_std(x): + return tf.tile( + tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1]) + + log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs) + logits_out = tf.keras.layers.Concatenate(axis=1)( + [logits_out, log_std_out]) + + last_vf_layer = None + if not vf_share_layers: + # Build a parallel set of hidden layers for the value net. 
+ last_vf_layer = inputs + i = 1 + for size in hiddens: + last_vf_layer = tf.keras.layers.Dense( + size, + name="fc_value_{}".format(i), + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_vf_layer) + i += 1 + + value_out = tf.keras.layers.Dense( + 1, + name="value_out", + activation=None, + kernel_initializer=normc_initializer(0.01))( + last_vf_layer if last_vf_layer is not None else last_layer) + + self.base_model = tf.keras.Model( + inputs, [(logits_out + if logits_out is not None else last_layer), value_out]) + + def call(self, input_dict: Dict[str, TensorType]) -> \ + Tuple[TensorType, List[TensorType], TensorType]: + model_out, value_out = self.base_model(input_dict["obs"]) + return model_out, [], { + SampleBatch.VF_PREDS: tf.reshape(value_out, [-1]) + } diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index cabdf809fe715..bb81f70e2727a 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -146,6 +146,7 @@ def __init__( self._loss_fn = loss_fn self._stats_fn = stats_fn self._grad_stats_fn = grad_stats_fn + self._seq_lens = None dist_class = dist_inputs = None if action_sampler_fn or action_distribution_fn: @@ -181,6 +182,7 @@ def __init__( v for k, v in existing_inputs.items() if k.startswith("state_in_") ] + # Placeholder for RNN time-chunk valid lengths. if self._state_inputs: self._seq_lens = existing_inputs["seq_lens"] else: @@ -191,12 +193,19 @@ def __init__( ) for k, vr in self.model.view_requirements.items() if k.startswith("state_in_") ] + # Placeholder for RNN time-chunk valid lengths. + if self._state_inputs: + self._seq_lens = tf1.placeholder( + dtype=tf.int32, shape=[None], name="seq_lens") # Use default settings. # Add NEXT_OBS, STATE_IN_0.., and others. self.view_requirements = self._get_default_view_requirements() # Combine view_requirements for Model and Policy. self.view_requirements.update(self.model.view_requirements) + # Disable env-info placeholder. + if SampleBatch.INFOS in self.view_requirements: + self.view_requirements[SampleBatch.INFOS].used_for_training = False # Setup standard placeholders. if existing_inputs is not None: @@ -225,9 +234,6 @@ def __init__( explore = tf1.placeholder_with_default( True, (), name="is_exploring") - # Placeholder for RNN time-chunk valid lengths. - self._seq_lens = tf1.placeholder( - dtype=tf.int32, shape=[None], name="seq_lens") # Placeholder for `is_training` flag. self._input_dict["is_training"] = self._get_is_training_placeholder() @@ -289,8 +295,12 @@ def __init__( # Default distribution generation behavior: # Pass through model. E.g., PG, PPO. 
else: - dist_inputs, self._state_out = self.model( - self._input_dict, self._state_inputs, self._seq_lens) + if isinstance(self.model, tf.keras.Model): + dist_inputs, self._state_out, self._extra_action_fetches =\ + self.model(self._input_dict) + else: + dist_inputs, self._state_out = self.model( + self._input_dict, self._state_inputs, self._seq_lens) action_dist = dist_class(dist_inputs, self.model) @@ -382,7 +392,7 @@ def copy(self, ]) instance._loss_input_dict = input_dict - loss = instance._do_loss_init(input_dict) + loss = instance._do_loss_init(SampleBatch(input_dict)) loss_inputs = [ (k, existing_inputs[i]) for i, k in enumerate(self._loss_input_dict_no_rnn.keys()) @@ -446,7 +456,7 @@ def _get_input_dict_and_dummy_batch(self, view_requirements, dummy_batch = self._get_dummy_batch_from_view_requirements( batch_size=32) - return input_dict, dummy_batch + return SampleBatch(input_dict, _seq_lens=self._seq_lens), dummy_batch def _initialize_loss_from_dummy_batch( self, auto_remove_unneeded_view_reqs: bool = True, @@ -588,6 +598,8 @@ def _do_loss_init(self, train_batch: SampleBatch): loss = self._loss_fn(self, self.model, self.dist_class, train_batch) if self._stats_fn: self._stats_fetches.update(self._stats_fn(self, train_batch)) - # override the update ops to be those of the model - self._update_ops = self.model.update_ops() + # Override the update ops to be those of the model. + self._update_ops = [] + if not isinstance(self.model, tf.keras.Model): + self._update_ops = self.model.update_ops() return loss diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 45c4a28c1210c..8837fbcce88ad 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -475,6 +475,9 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \ else None + # Add default and custom fetches. + extra_fetches = {} + # Use Exploration object. with tf.variable_creator_scope(_disallow_var_creation): if action_sampler_fn: @@ -521,6 +524,11 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, is_training=False) else: raise e + elif isinstance(self.model, tf.keras.Model): + input_dict = SampleBatch(input_dict, seq_lens=seq_lens) + self._lazy_tensor_dict(input_dict) + dist_inputs, state_out, extra_fetches = \ + self.model(input_dict) else: dist_inputs, state_out = self.model( input_dict, state_batches, seq_lens) @@ -533,8 +541,6 @@ def _compute_action_helper(self, input_dict, state_batches, episodes, timestep=timestep, explore=explore) - # Add default and custom fetches. - extra_fetches = {} # Action-logp and action-prob. 
if logp is not None: extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp) @@ -651,7 +657,10 @@ def set_state(self, state): def variables(self): """Return the list of all savable variables for this policy.""" - return self.model.variables() + if isinstance(self.model, tf.keras.Model): + return self.model.variables + else: + return self.model.variables() @override(Policy) def is_recurrent(self): @@ -704,7 +713,10 @@ def _compute_gradients(self, samples): with tf.GradientTape(persistent=gradients_fn is not None) as tape: loss = loss_fn(self, self.model, self.dist_class, samples) - variables = self.model.trainable_variables() + if isinstance(self.model, tf.keras.Model): + variables = self.model.trainable_variables + else: + variables = self.model.trainable_variables() if gradients_fn: diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 52a12eb060f39..c148d9709ad3b 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -750,7 +750,7 @@ def _get_dummy_batch_from_view_requirements( # Due to different view requirements for the different columns, # columns in the resulting batch may not all have the same batch size. - return SampleBatch(ret, _dont_check_lens=True) + return SampleBatch(ret) def _update_model_view_requirements_from_init_state(self): """Uses Model's (or this Policy's) init state to add needed ViewReqs. @@ -762,12 +762,21 @@ def _update_model_view_requirements_from_init_state(self): self._model_init_state_automatically_added = True model = getattr(self, "model", None) obj = model or self + if model and not hasattr(model, "view_requirements"): + model.view_requirements = { + SampleBatch.OBS: ViewRequirement(space=self.observation_space) + } + view_reqs = obj.view_requirements # Add state-ins to this model's view. - for i, state in enumerate(obj.get_initial_state()): + init_state = [] + if hasattr(obj, "get_initial_state") and callable( + obj.get_initial_state): + init_state = obj.get_initial_state() + else: + obj.get_initial_state = lambda: [] + for i, state in enumerate(init_state): space = Box(-1.0, 1.0, shape=state.shape) if \ hasattr(state, "shape") else state - view_reqs = model.view_requirements if model else \ - self.view_requirements view_reqs["state_in_{}".format(i)] = ViewRequirement( "state_out_{}".format(i), shift=-1, diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index a139fedb3250a..ca2eba4cf6952 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -416,8 +416,7 @@ def timeslice_along_seq_lens_with_overlap( i += 1 key = "state_in_{}".format(i) - timeslices.append( - SampleBatch(data, _seq_lens=[end - begin], _dont_check_lens=True)) + timeslices.append(SampleBatch(data, _seq_lens=[end - begin])) # Zero-pad each slice if necessary. 
if zero_pad_max_seq_len > 0: diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index f3be410160711..96b19270d6abf 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -72,7 +72,6 @@ def __init__(self, *args, **kwargs): if isinstance(self.seq_lens, list): self.seq_lens = np.array(self.seq_lens, dtype=np.int32) - self.dont_check_lens = kwargs.pop("_dont_check_lens", False) self.max_seq_len = kwargs.pop("_max_seq_len", None) if self.max_seq_len is None and self.seq_lens is not None and \ not (tf and tf.is_tensor(self.seq_lens)) and \ @@ -106,20 +105,12 @@ def __init__(self, *args, **kwargs): if isinstance(v, list): self[k] = np.array(v) - if not lengths: - raise ValueError("Empty sample batch") - - if not self.dont_check_lens: - assert len(set(lengths)) == 1, \ - "Data columns must be same length, but lens are " \ - "{}".format(lengths) - if self.seq_lens is not None and \ not (tf and tf.is_tensor(self.seq_lens)) and \ len(self.seq_lens) > 0: self.count = sum(self.seq_lens) else: - self.count = lengths[0] + self.count = lengths[0] if lengths else 0 @PublicAPI def __len__(self): @@ -163,7 +154,6 @@ def concat_samples(samples: List["SampleBatch"]) -> \ out, _seq_lens=np.array(seq_lens, dtype=np.int32), _time_major=concat_samples[0].time_major, - _dont_check_lens=True, _zero_padded=zero_padded, _max_seq_len=max_seq_len, ) @@ -213,7 +203,7 @@ def copy(self, shallow: bool = False) -> "SampleBatch": for (k, v) in self.items() }, _seq_lens=self.seq_lens, - _dont_check_lens=self.dont_check_lens) + ) copy_.set_get_interceptor(self.get_interceptor) return copy_ @@ -355,7 +345,7 @@ def slice(self, start: int, end: int) -> "SampleBatch": data, _seq_lens=np.array(seq_lens, dtype=np.int32), _time_major=self.time_major, - _dont_check_lens=True) + ) else: return SampleBatch( {k: v[start:end] @@ -592,6 +582,71 @@ def data(self): old="SampleBatch.data[..]", new="SampleBatch[..]", error=False) return self + # TODO: (sven) Experimental method. + def get_single_step_input_dict(self, view_requirements, index="last"): + """Creates single ts SampleBatch at given index from `self`. + + For usage as input-dict for model calls. + + Args: + sample_batch (SampleBatch): A single-trajectory SampleBatch object + to generate the compute_actions input dict from. + index (Union[int, str]): An integer index value indicating the + position in the trajectory for which to generate the + compute_actions input dict. Set to "last" to generate the dict + at the very end of the trajectory (e.g. for value estimation). + Note that "last" is different from -1, as "last" will use the + final NEXT_OBS as observation input. + + Returns: + SampleBatch: The (single-timestep) input dict for ModelV2 calls. + """ + last_mappings = { + SampleBatch.OBS: SampleBatch.NEXT_OBS, + SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, + SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, + } + + input_dict = {} + for view_col, view_req in view_requirements.items(): + # Create batches of size 1 (single-agent input-dict). + data_col = view_req.data_col or view_col + if index == "last": + data_col = last_mappings.get(data_col, data_col) + # Range needed. 
+ if view_req.shift_from is not None: + data = self[view_col][-1] + traj_len = len(self[data_col]) + missing_at_end = traj_len % view_req.batch_repeat_value + obs_shift = -1 if data_col in [ + SampleBatch.OBS, SampleBatch.NEXT_OBS + ] else 0 + from_ = view_req.shift_from + obs_shift + to_ = view_req.shift_to + obs_shift + 1 + if to_ == 0: + to_ = None + input_dict[view_col] = np.array([ + np.concatenate( + [data, + self[data_col][-missing_at_end:]])[from_:to_] + ]) + # Single index. + else: + data = self[data_col][-1] + input_dict[view_col] = np.array([data]) + else: + # Index range. + if isinstance(index, tuple): + data = self[data_col][index[0]:index[1] + + 1 if index[1] != -1 else None] + input_dict[view_col] = np.array([data]) + # Single index. + else: + input_dict[view_col] = self[data_col][ + index:index + 1 if index != -1 else None] + + return SampleBatch(input_dict, _seq_lens=np.array([1], dtype=np.int32)) + @PublicAPI class MultiAgentBatch: diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index 9433b77f9a832..ebfd4c9b97f6b 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -20,9 +20,9 @@ def test_dict_properties_of_sample_batches(self): "b": np.array([[0.1, 0.2], [0.3, 0.4]]), "c": True, } - batch = SampleBatch(base_dict, _dont_check_lens=True) + batch = SampleBatch(base_dict) try: - SampleBatch(base_dict, _dont_check_lens=False) + SampleBatch(base_dict) except AssertionError: pass # expected keys_ = list(base_dict.keys()) diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 2b0d1b3754254..b6e01a24f2128 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -156,9 +156,9 @@ def __init__(self, self.view_requirements[ SampleBatch.INFOS].used_for_compute_actions = False - assert model is None or isinstance(model, ModelV2), \ - "Model classes for TFPolicy other than `ModelV2` not allowed! " \ - "You passed in {}.".format(model) + assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), \ + "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` " \ + "not allowed! You passed in {}.".format(model) self.model = model # Auto-update model's inference view requirements, if recurrent. if self.model is not None: @@ -227,7 +227,10 @@ def __init__(self, def variables(self): """Return the list of all savable variables for this policy.""" - return self.model.variables() + if isinstance(self.model, tf.keras.Model): + return self.model.variables + else: + return self.model.variables() def get_placeholder(self, name) -> "tf1.placeholder": """Returns the given action or loss input placeholder by name. @@ -281,7 +284,7 @@ def _initialize_loss(self, loss: TensorType, for i, ph in enumerate(self._state_inputs): self._loss_input_dict["state_in_{}".format(i)] = ph - if self.model: + if self.model and not isinstance(self.model, tf.keras.Model): self._loss = self.model.custom_loss(loss, self._loss_input_dict) self._stats_fetches.update({ "model": self.model.metrics() if isinstance( @@ -919,7 +922,7 @@ def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool): # Mark the batch as "is_training" so the Model can use this # information. - train_batch["is_training"] = True + train_batch.is_training = True # Build the feed dict from the batch. 
feed_dict = {} diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index 817646ad656c6..ed02657e9fd78 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -214,11 +214,16 @@ def before_loss_init_wrapper(policy, obs_space, action_space, config): if before_loss_init: before_loss_init(policy, obs_space, action_space, config) + if extra_action_out_fn is None: - policy._extra_action_fetches = {} + extra_action_fetches = {} + else: + extra_action_fetches = extra_action_out_fn(policy) + + if hasattr(policy, "_extra_action_fetches"): + policy._extra_action_fetches.update(extra_action_fetches) else: - policy._extra_action_fetches = extra_action_out_fn(policy) - policy._extra_action_fetches = extra_action_out_fn(policy) + policy._extra_action_fetches = extra_action_fetches DynamicTFPolicy.__init__( self, diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index 524cf77b6b782..603f542595b03 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -63,7 +63,8 @@ def true_fn(): batch_size = 1 req = force_tuple( action_dist.required_model_output_shape( - self.action_space, self.model.model_config)) + self.action_space, getattr(self.model, "model_config", + None))) # Add a batch dimension? if len(action_dist.inputs.shape) == len(req) + 1: batch_size = tf.shape(action_dist.inputs)[0] @@ -138,7 +139,8 @@ def get_torch_exploration_action(self, action_dist: ActionDistribution, if explore: req = force_tuple( action_dist.required_model_output_shape( - self.action_space, self.model.model_config)) + self.action_space, getattr(self.model, "model_config", + None))) # Add a batch dimension? if len(action_dist.inputs.shape) == len(req) + 1: batch_size = action_dist.inputs.shape[0]
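
Usage sketch (assumption: a Ray/RLlib checkout with this patch applied). The snippet below shows how a native tf.keras.Model can be plugged into PPO via config["model"]["custom_model"], mirroring the constructor signature (input_space, action_space, num_outputs, *, name, **kwargs) and the call() return convention (logits, state_outs, {SampleBatch.VF_PREDS: ...}) used by Keras_FullyConnectedNetwork and the RNNModel example introduced here. The class name MyKerasPolicyModel and its hidden_size option are illustrative only, not part of the patch.

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class MyKerasPolicyModel(tf.keras.Model if tf else object):
    """Tiny feed-forward policy+value net written as a native tf.keras.Model.

    Hypothetical user model; follows the same constructor/call conventions as
    the Keras_FullyConnectedNetwork added by this patch.
    """

    def __init__(self, input_space, action_space, num_outputs, *,
                 name="", hidden_size=64, **kwargs):
        super().__init__(name=name)
        self.dense = tf.keras.layers.Dense(hidden_size, activation="relu")
        self.logits = tf.keras.layers.Dense(num_outputs, activation=None)
        self.values = tf.keras.layers.Dense(1, activation=None)

    def call(self, input_dict):
        hidden = self.dense(input_dict["obs"])
        logits = self.logits(hidden)
        values = tf.reshape(self.values(hidden), [-1])
        # Third return value: extra per-call fetches (here: VF predictions),
        # which replace the ModelV2 value_function() / extra-fetches plumbing.
        return logits, [], {SampleBatch.VF_PREDS: values}


if __name__ == "__main__":
    ray.init()
    config = {
        "env": "CartPole-v0",
        "framework": "tf",
        "num_workers": 0,
        "model": {
            "custom_model": MyKerasPolicyModel,
            "custom_model_config": {"hidden_size": 64},
        },
    }
    trainer = ppo.PPOTrainer(config=config)
    print(trainer.train())
    ray.shutdown()

Alternatively, instead of a custom class, setting "_use_default_native_models": True in the model config (as test_ppo.py does in this patch) lets the catalog pick the new Keras_FullyConnectedNetwork default for flat observation spaces, as long as no LSTM/attention auto-wrapping is requested.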