[RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). #14684
Changes to the PPO TF policy:

@@ -29,21 +29,29 @@
 def ppo_surrogate_loss(
-        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
+        policy: Policy, model: Union[ModelV2, tf.keras.Model],
+        dist_class: Type[TFActionDistribution],
         train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
     """Constructs the loss for Proximal Policy Objective.

     Args:
         policy (Policy): The Policy to calculate the loss for.
-        model (ModelV2): The Model to calculate the loss for.
+        model (Union[ModelV2, tf.keras.Model]): The Model to calculate
+            the loss for.
         dist_class (Type[ActionDistribution]): The action distr. class.
         train_batch (SampleBatch): The training data.

     Returns:
         Union[TensorType, List[TensorType]]: A single loss tensor or a list
             of loss tensors.
     """
-    logits, state = model(train_batch)
+    if isinstance(model, tf.keras.Model):
+        logits, state, extra_outs = model(train_batch)
+        value_fn_out = extra_outs[SampleBatch.VF_PREDS]
+    else:
+        logits, state = model(train_batch)
+        value_fn_out = model.value_function()

     curr_action_dist = dist_class(logits, model)

     # RNN case: Mask away 0-padded chunks at end of time axis.
@@ -86,7 +94,6 @@ def reduce_mean_valid(t):

     if policy.config["use_gae"]:
         prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
-        value_fn_out = model.value_function()
         vf_loss1 = tf.math.square(value_fn_out -
                                   train_batch[Postprocessing.VALUE_TARGETS])
         vf_clipped = prev_value_fn_out + tf.clip_by_value(

Review comment (on the removed value_function() call): One reason for the new API is that we can get rid of these async

@@ -112,6 +119,7 @@ def reduce_mean_valid(t):
     policy._mean_vf_loss = mean_vf_loss
     policy._mean_entropy = mean_entropy
     policy._mean_kl = mean_kl
+    policy._value_fn_out = value_fn_out

     return total_loss

@@ -134,14 +142,14 @@ def kl_and_loss_stats(policy: Policy,
         "policy_loss": policy._mean_policy_loss,
         "vf_loss": policy._mean_vf_loss,
         "vf_explained_var": explained_variance(
-            train_batch[Postprocessing.VALUE_TARGETS],
-            policy.model.value_function()),
+            train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out),
         "kl": policy._mean_kl,
         "entropy": policy._mean_entropy,
         "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
     }


+# TODO: (sven) Deprecate once we only allow native keras models.
 def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
     """Defines extra fetches per action computation.

@@ -152,6 +160,10 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
         Dict[str, TensorType]: Dict with extra tf fetches to perform per
             action computation.
     """
+    # Keras models return values for each call in third return argument
+    # (dict).
+    if isinstance(policy.model, tf.keras.Model):
+        return {}
     # Return value function outputs. VF estimates will hence be added to the
     # SampleBatches produced by the sampler(s) to generate the train batches
     # going into the loss function.

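For context, the pre-existing ModelV2 branch of vf_preds_fetches is not shown in this hunk. The sketch below reconstructs the whole function as it presumably reads after this change; the ModelV2 return value is an assumption based on standard RLlib behavior, and the module's existing imports (tf, SampleBatch) are assumed:

def vf_preds_fetches_sketch(policy):
    # Hypothetical reconstruction, not copied from this diff.
    if isinstance(policy.model, tf.keras.Model):
        # Native keras models already emit VF_PREDS via the third (extra-outs)
        # return value of their call(), so nothing extra needs to be fetched.
        return {}
    # ModelV2 path: fetch the value-function output alongside each computed
    # action so it lands in the sampled batches.
    return {SampleBatch.VF_PREDS: policy.model.value_function()}
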
@@ -256,9 +268,13 @@ def __init__(self, obs_space, action_space, config):
             @make_tf_callable(self.get_session())
             def value(**input_dict):
                 input_dict = SampleBatch(input_dict)
-                model_out, _ = self.model(input_dict)
-                # [0] = remove the batch dim.
-                return self.model.value_function()[0]
+                if isinstance(self.model, tf.keras.Model):
+                    _, _, extra_outs = self.model(input_dict)
+                    return extra_outs[SampleBatch.VF_PREDS][0]
+                else:
+                    model_out, _ = self.model(input_dict)
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]

         # TODO: (sven) Remove once trajectory view API is all-algo default.
         else:
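For reference, the value() closure above is typically stored on the policy (as self._value) and used during GAE postprocessing to bootstrap the value estimate of a truncated rollout's last timestep. The snippet below is a hypothetical call site with illustrative keys, not code from this PR; the exact keys depend on the model's view requirements:

# Hypothetical bootstrap of the last timestep's value (illustrative keys only).
last_r = policy._value(
    obs=sample_batch[SampleBatch.NEXT_OBS][-1:],
    prev_actions=sample_batch[SampleBatch.ACTIONS][-1:],
    prev_rewards=sample_batch[SampleBatch.REWARDS][-1:])
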
Changes to the example RNN model (native keras version):

@@ -1,5 +1,6 @@
 import numpy as np

+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch

 tf1, tf, tfv = try_import_tf()

@@ -13,11 +14,11 @@ def __init__(self,
                  input_space,
                  action_space,
                  num_outputs,
-                 model_config,
-                 name,
+                 *,
+                 name="",
                  hiddens_size=256,
                  cell_size=64):
-        super(RNNModel, self).__init__()
+        super().__init__(name=name)

         self.cell_size = cell_size

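With name now keyword-only (and model_config dropped from the signature), constructing the example model would look roughly as follows; the argument values are illustrative:

# Hypothetical instantiation (spaces and output size depend on the env).
model = RNNModel(input_space, action_space, num_outputs,
                 name="example_rnn", hiddens_size=256, cell_size=64)
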
@@ -46,14 +47,11 @@ def call(self, sample_batch):
         )
         lstm_out = tf.reshape(lstm_out, [-1, lstm_out.shape.as_list()[2]])
         logits = self.logits(lstm_out)
-        self._value_out = self.values(lstm_out)
-        return logits, [h, c]
+        values = tf.reshape(self.values(lstm_out), [-1])
+        return logits, [h, c], {SampleBatch.VF_PREDS: values}

     def get_initial_state(self):
         return [
             np.zeros(self.cell_size, np.float32),
             np.zeros(self.cell_size, np.float32),
         ]

-    def value_function(self):
-        return tf.reshape(self._value_out, [-1])

Review comment (on get_initial_state): We should be able to get rid of this as well in the future via the trajectory view API.
Review comment: Some users would want different types of initial states (e.g. Xavier initialization), right?
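Regarding the review question above: a model wanting non-zero initial states could simply override get_initial_state() with a different initializer. A minimal sketch of such a variant inside the model class, using Glorot/Xavier uniform (eager mode assumed; not part of this PR):

    def get_initial_state(self):
        # Hypothetical variant: Xavier/Glorot-initialized initial LSTM state
        # instead of zeros.
        initializer = tf.keras.initializers.GlorotUniform()
        return [
            initializer(shape=(self.cell_size,)).numpy(),
            initializer(shape=(self.cell_size,)).numpy(),
        ]
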
Review comment: New model API: native + return 3 values (model-out, state-outs, extra-outs (e.g. vf)).
Review comment: Assuming ModelV2 will be deprecated soon :)
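To make the new contract concrete, here is a minimal, self-contained sketch (not from this PR) of a non-recurrent native keras model following the 3-value convention; the class name, layer sizes, and dummy input are made up:

import tensorflow as tf
from ray.rllib.policy.sample_batch import SampleBatch

class TinyKerasModel(tf.keras.Model):
    """Hypothetical minimal model following the new call() contract."""

    def __init__(self, num_outputs=2, hidden=32, name="tiny_model"):
        super().__init__(name=name)
        self.hidden_layer = tf.keras.layers.Dense(hidden, activation="relu")
        self.logits_layer = tf.keras.layers.Dense(num_outputs)
        self.value_layer = tf.keras.layers.Dense(1)

    def call(self, input_dict):
        features = self.hidden_layer(input_dict["obs"])
        logits = self.logits_layer(features)
        values = tf.reshape(self.value_layer(features), [-1])
        # model-out, state-outs (none here), extra-outs (incl. VF preds).
        return logits, [], {SampleBatch.VF_PREDS: values}

# Usage sketch (eager mode); a plain dict stands in for the SampleBatch here.
model = TinyKerasModel()
logits, state_out, extra_outs = model({"obs": tf.ones([4, 8])})
vf_preds = extra_outs[SampleBatch.VF_PREDS]  # shape [4], consumed by the loss above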