Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). #14684

Merged
merged 48 commits into from
Apr 27, 2021

Conversation

sven1977
Copy link
Contributor

@sven1977 sven1977 commented Mar 15, 2021

This PR adds support for native keras models in TF- and eager Policies.

This is an experimental API extension that's fully backward compatible (ModelV2s will still work, of course!).

  • Users can define a pure tf.keras.Model as their model and have the TF/eager policy use it via the "custom_model" config key.
  • The keras model only has to implement a get_initial_state method (this requirement may be removed as well to support functional model building or using tf.keras.Sequential). All other ModelV2 methods are no longer required.
  • The keras.Model call method must be implemented for the forward logic.
  • An example (RNN) model class has been added (examples/models/modelv3.py), as well as a test case confirming that PPO is learning CartPole using this approach.
  • A new (experimental) config flag "_use_default_native_models" (False by default) has been added to catalog.py. If True, RLlib will use built-in keras default models where possible (e.g. done in the PPO test). Setting this flag to True is save (RLlib will not try to use keras models where this is not supported, e.g. for torch, for VisionNets, etc..).

NOTE: This PR is for TF only. A solution for supporting pure native torch.Modules is to follow.

Follow up PRs will include:

  • Making all default tf nets native keras (VisionNet + RNN + LSTM-/attention-wrappers).
  • Support for PyTorch native Modules.
  • Examples for custom Sequential types, RNNs (how to specify initial states?), trajectory view API, etc..

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…dict_via_sample_batch

# Conflicts:
#	rllib/policy/sample_batch.py
…lete_usage_tracking_dict_via_sample_batch
…lete_usage_tracking_dict_via_sample_batch

# Conflicts:
#	rllib/agents/marwil/marwil_tf_policy.py
#	rllib/agents/marwil/marwil_torch_policy.py
#	rllib/models/tf/recurrent_net.py
#	rllib/models/torch/recurrent_net.py
#	rllib/offline/json_reader.py
#	rllib/policy/sample_batch.py
…lete_usage_tracking_dict_via_sample_batch

# Conflicts:
#	rllib/agents/ppo/ppo_tf_policy.py
#	rllib/models/modelv2.py
#	rllib/policy/dynamic_tf_policy.py
#	rllib/policy/eager_tf_policy.py
#	rllib/policy/policy.py
#	rllib/policy/sample_batch.py
#	rllib/policy/torch_policy.py
#	rllib/utils/tracking_dict.py
…lete_usage_tracking_dict_via_sample_batch
…lete_modelv2_class

# Conflicts:
#	rllib/agents/ddpg/ddpg_torch_policy.py
#	rllib/evaluation/postprocessing.py
#	rllib/models/modelv2.py
#	rllib/policy/dynamic_tf_policy.py
#	rllib/policy/sample_batch.py
@sven1977 sven1977 marked this pull request as ready for review March 23, 2021 13:55
@sven1977 sven1977 requested a review from michaelzhiluo March 23, 2021 13:55
@sven1977 sven1977 added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Apr 13, 2021
@@ -67,7 +67,8 @@ def postprocess_advantages(policy,
# input_dict.
# Create an input dict according to the Model's requirements.
index = "last" if SampleBatch.NEXT_OBS in sample_batch.data else -1
input_dict = policy.model.get_input_dict(sample_batch, index=index)
input_dict = sample_batch.get_single_step_input_dict(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this helper method to SampleBatch (from ModelV2), where it's more fitting.

dist_class (Type[ActionDistribution]: The action distr. class.
train_batch (SampleBatch): The training data.

Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
logits, state = model.from_batch(train_batch)
if isinstance(model, tf.keras.Model):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New model API: native + return 3 values (model-out, state-outs, extra-outs (e.g. vf)).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming ModelV2 will be deprecated soon :)

@@ -86,7 +94,6 @@ def reduce_mean_valid(t):

if policy.config["use_gae"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One reason for the new API is that we can get rid of these async value_function calls.

@@ -82,6 +82,9 @@ def test_ppo_compilation_and_lr_schedule(self):
# Settings in case we use an LSTM.
config["model"]["lstm_cell_size"] = 10
config["model"]["max_seq_len"] = 20
# Use default-native keras model whenever possible.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing, whether this flag works (try using native keras default model).
Currently only works for FCNets.

values = tf.reshape(self.values(lstm_out), [-1])
return logits, [h, c], {SampleBatch.VF_PREDS: values}

def get_initial_state(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to get rid of this as well in the future via the trajectory view API.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some users would want different types of initial states (e.g. xavier initalization) right?

@@ -333,71 +333,6 @@ def is_time_major(self) -> bool:
"""
return self.time_major is True

# TODO: (sven) Experimental method.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was moved to SampleBatch instead (makes more sense there).
SampleBatch converts itself into an single-timestep input-dict (which is also a SampleBatch).

@@ -588,6 +579,71 @@ def data(self):
old="SampleBatch.data[..]", new="SampleBatch[..]", error=False)
return self

# TODO: (sven) Experimental method.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved here from ModelV2 (makes more sense inside SampleBatch).

…lete_modelv2_class

# Conflicts:
#	rllib/agents/marwil/marwil_tf_policy.py
…lete_modelv2_class

# Conflicts:
#	rllib/policy/torch_policy.py
…lete_modelv2_class

# Conflicts:
#	rllib/agents/trainer.py
i = 1

# Create layers 0 to second-last.
for size in hiddens[:-1]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this code copy pasted from the prior modelv2 fcnet for tf?

@sven1977 sven1977 merged commit bb8a286 into ray-project:master Apr 27, 2021
@smorad
Copy link
Contributor

smorad commented May 6, 2021

@sven1977
This breaks rllib if tensorflow is not installed.

if issubclass(model_class, tf.keras.Model):

Should be if if tf and issubclass(model_class, tf.keras.Model):

and

if framework not in ["tf", "tf2", "tfe"] or \

Should be if tf and framework not in ["tf", "tf2", "tfe"]

I'd patch this but I've managed to totally mangle my fork/local repo atm.

@sven1977 sven1977 deleted the obsolete_modelv2_class branch June 2, 2023 20:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tests-ok The tagger certifies test failures are unrelated and assumes personal liability.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants