-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). #14684
[RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). #14684
Conversation
…dict_via_sample_batch # Conflicts: # rllib/policy/sample_batch.py
…lete_usage_tracking_dict_via_sample_batch
…lete_usage_tracking_dict_via_sample_batch # Conflicts: # rllib/agents/marwil/marwil_tf_policy.py # rllib/agents/marwil/marwil_torch_policy.py # rllib/models/tf/recurrent_net.py # rllib/models/torch/recurrent_net.py # rllib/offline/json_reader.py # rllib/policy/sample_batch.py
…lete_usage_tracking_dict_via_sample_batch # Conflicts: # rllib/agents/ppo/ppo_tf_policy.py # rllib/models/modelv2.py # rllib/policy/dynamic_tf_policy.py # rllib/policy/eager_tf_policy.py # rllib/policy/policy.py # rllib/policy/sample_batch.py # rllib/policy/torch_policy.py # rllib/utils/tracking_dict.py
…lete_usage_tracking_dict_via_sample_batch
…lete_modelv2_class # Conflicts: # rllib/agents/ddpg/ddpg_torch_policy.py # rllib/evaluation/postprocessing.py # rllib/models/modelv2.py # rllib/policy/dynamic_tf_policy.py # rllib/policy/sample_batch.py
…lete_modelv2_class
…lete_modelv2_class
…lete_modelv2_class
…lete_modelv2_class
@@ -67,7 +67,8 @@ def postprocess_advantages(policy, | |||
# input_dict. | |||
# Create an input dict according to the Model's requirements. | |||
index = "last" if SampleBatch.NEXT_OBS in sample_batch.data else -1 | |||
input_dict = policy.model.get_input_dict(sample_batch, index=index) | |||
input_dict = sample_batch.get_single_step_input_dict( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved this helper method to SampleBatch (from ModelV2), where it's more fitting.
dist_class (Type[ActionDistribution]: The action distr. class. | ||
train_batch (SampleBatch): The training data. | ||
|
||
Returns: | ||
Union[TensorType, List[TensorType]]: A single loss tensor or a list | ||
of loss tensors. | ||
""" | ||
logits, state = model.from_batch(train_batch) | ||
if isinstance(model, tf.keras.Model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New model API: native + return 3 values (model-out, state-outs, extra-outs (e.g. vf)).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming ModelV2 will be deprecated soon :)
@@ -86,7 +94,6 @@ def reduce_mean_valid(t): | |||
|
|||
if policy.config["use_gae"]: | |||
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] | |||
value_fn_out = model.value_function() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One reason for the new API is that we can get rid of these async value_function
calls.
@@ -82,6 +82,9 @@ def test_ppo_compilation_and_lr_schedule(self): | |||
# Settings in case we use an LSTM. | |||
config["model"]["lstm_cell_size"] = 10 | |||
config["model"]["max_seq_len"] = 20 | |||
# Use default-native keras model whenever possible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Testing, whether this flag works (try using native keras default model).
Currently only works for FCNets.
values = tf.reshape(self.values(lstm_out), [-1]) | ||
return logits, [h, c], {SampleBatch.VF_PREDS: values} | ||
|
||
def get_initial_state(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be able to get rid of this as well in the future via the trajectory view API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some users would want different types of initial states (e.g. xavier initalization) right?
@@ -333,71 +333,6 @@ def is_time_major(self) -> bool: | |||
""" | |||
return self.time_major is True | |||
|
|||
# TODO: (sven) Experimental method. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was moved to SampleBatch instead (makes more sense there).
SampleBatch converts itself into an single-timestep input-dict (which is also a SampleBatch).
@@ -588,6 +579,71 @@ def data(self): | |||
old="SampleBatch.data[..]", new="SampleBatch[..]", error=False) | |||
return self | |||
|
|||
# TODO: (sven) Experimental method. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved here from ModelV2 (makes more sense inside SampleBatch).
…lete_modelv2_class
…lete_modelv2_class
…lete_modelv2_class # Conflicts: # rllib/agents/marwil/marwil_tf_policy.py
…lete_modelv2_class
…lete_modelv2_class # Conflicts: # rllib/policy/torch_policy.py
…lete_modelv2_class # Conflicts: # rllib/agents/trainer.py
i = 1 | ||
|
||
# Create layers 0 to second-last. | ||
for size in hiddens[:-1]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this code copy pasted from the prior modelv2 fcnet for tf?
@sven1977 Line 741 in e973b72
Should be if if tf and issubclass(model_class, tf.keras.Model):
and Line 408 in e973b72
Should be if tf and framework not in ["tf", "tf2", "tfe"]
I'd patch this but I've managed to totally mangle my fork/local repo atm. |
This PR adds support for native keras models in TF- and eager Policies.
This is an experimental API extension that's fully backward compatible (ModelV2s will still work, of course!).
get_initial_state
method (this requirement may be removed as well to support functional model building or using tf.keras.Sequential). All other ModelV2 methods are no longer required.call
method must be implemented for the forward logic.NOTE: This PR is for TF only. A solution for supporting pure native torch.Modules is to follow.
Follow up PRs will include:
Why are these changes needed?
Related issue number
Checks
scripts/format.sh
to lint the changes in this PR.