
[RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). #14684

Merged Apr 27, 2021 (48 commits). Showing changes from 1 commit.
Commits:
7c5f7ce  WIP. (sven1977, Dec 23, 2020)
8f74a7b  WIP. (sven1977, Dec 23, 2020)
a8cb715  WIP. (sven1977, Dec 23, 2020)
43fcfc5  WIP. (sven1977, Dec 23, 2020)
8f95df7  Fix. (sven1977, Dec 24, 2020)
ef80373  Merge branch 'bc_marwil_minor_cleanups' into obsolete_usage_tracking_… (sven1977, Dec 24, 2020)
8dd46c2  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Dec 26, 2020)
40a6d8b  Fixes and LINT. (sven1977, Dec 26, 2020)
e6e8ecb  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Dec 27, 2020)
bb728ef  Fixes and LINT. (sven1977, Dec 27, 2020)
4f16efe  WIP. (sven1977, Dec 28, 2020)
56a4e1b  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Mar 13, 2021)
371a298  wip (sven1977, Mar 13, 2021)
bdb7c93  wip (sven1977, Mar 14, 2021)
f03f49d  wip (sven1977, Mar 14, 2021)
d4e31de  wip (sven1977, Mar 14, 2021)
1988c53  wip (sven1977, Mar 14, 2021)
493424a  wip (sven1977, Mar 14, 2021)
a74d668  wip (sven1977, Mar 14, 2021)
a979ca8  wip (sven1977, Mar 15, 2021)
e93354a  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Mar 15, 2021)
3fa7e55  wip (sven1977, Mar 15, 2021)
deaa7f5  wip (sven1977, Mar 15, 2021)
287ea97  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Mar 18, 2021)
214332a  wip. (sven1977, Mar 18, 2021)
94f8f71  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Mar 23, 2021)
b92c702  fix. (sven1977, Mar 23, 2021)
b393f03  wip. (sven1977, Mar 23, 2021)
d33afea  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Mar 27, 2021)
e41d4ef  wip. (sven1977, Mar 27, 2021)
eb2ab5d  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 10, 2021)
96bacbf  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 10, 2021)
0a5bf55  wip. (sven1977, Apr 10, 2021)
32f5140  fix (sven1977, Apr 10, 2021)
131ec8f  fix and LINT. (sven1977, Apr 11, 2021)
7c11cec  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 11, 2021)
76cdc92  fixes. (sven1977, Apr 11, 2021)
b8ea5a1  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 11, 2021)
301e6f9  fixes. (sven1977, Apr 11, 2021)
f097b81  fix and LINT. (sven1977, Apr 11, 2021)
6ddb1a6  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 12, 2021)
86c6349  wip. (sven1977, Apr 13, 2021)
e07ed80  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 13, 2021)
4677ee5  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 15, 2021)
400092c  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 15, 2021)
469fca6  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 15, 2021)
ab31bf5  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 19, 2021)
d768a0e  Merge branch 'master' of https://github.com/ray-project/ray into obso… (sven1977, Apr 20, 2021)
Commit b393f037e6fddd881f7f2bee1e901aa10bd865db ("wip."), committed by sven1977 on Mar 23, 2021.
34 changes: 25 additions & 9 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -29,21 +29,29 @@

 def ppo_surrogate_loss(
-        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
+        policy: Policy, model: Union[ModelV2, tf.keras.Model],
+        dist_class: Type[TFActionDistribution],
         train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
     """Constructs the loss for Proximal Policy Objective.

     Args:
         policy (Policy): The Policy to calculate the loss for.
-        model (ModelV2): The Model to calculate the loss for.
+        model (Union[ModelV2, tf.keras.Model]): The Model to calculate
+            the loss for.
         dist_class (Type[ActionDistribution]): The action distr. class.
         train_batch (SampleBatch): The training data.

     Returns:
         Union[TensorType, List[TensorType]]: A single loss tensor or a list
             of loss tensors.
     """
-    logits, state = model(train_batch)
+    if isinstance(model, tf.keras.Model):
+        logits, state, extra_outs = model(train_batch)
+        value_fn_out = extra_outs[SampleBatch.VF_PREDS]
+    else:
+        logits, state = model(train_batch)
+        value_fn_out = model.value_function()

     curr_action_dist = dist_class(logits, model)

     # RNN case: Mask away 0-padded chunks at end of time axis.

Review comments on the new `isinstance(model, tf.keras.Model)` branch:

sven1977 (Contributor, Author): New model API: native + return 3 values (model-out, state-outs, extra-outs (e.g. vf)).

A reviewer (Contributor): Assuming ModelV2 will be deprecated soon :)
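
The review thread above summarizes the contract this branch introduces. As a hedged sketch (not code from this PR; the `forward` helper is hypothetical, and `tf` and `SampleBatch` are the imports already used in ppo_tf_policy.py), the two calling conventions the loss now has to support look like this:

    # Minimal sketch of the two model-call conventions, assuming `model`
    # and `train_batch` are set up as in the loss function above.
    def forward(model, train_batch):
        if isinstance(model, tf.keras.Model):
            # Native keras API: a single call returns model-out,
            # state-outs, and a dict of extra outputs (e.g. VF preds).
            logits, state_outs, extra_outs = model(train_batch)
            value_fn_out = extra_outs[SampleBatch.VF_PREDS]
        else:
            # ModelV2 API: extras are fetched via a separate method call
            # after the forward pass.
            logits, state_outs = model(train_batch)
            value_fn_out = model.value_function()
        return logits, state_outs, value_fn_out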
@@ -86,7 +94,6 @@ def reduce_mean_valid(t):

     if policy.config["use_gae"]:
         prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
-        value_fn_out = model.value_function()
         vf_loss1 = tf.math.square(value_fn_out -
                                   train_batch[Postprocessing.VALUE_TARGETS])
         vf_clipped = prev_value_fn_out + tf.clip_by_value(

Review comment on the removed `model.value_function()` call:

sven1977 (Contributor, Author): One reason for the new API is that we can get rid of these async value_function calls.
@@ -112,6 +119,7 @@ def reduce_mean_valid(t):

     policy._mean_vf_loss = mean_vf_loss
     policy._mean_entropy = mean_entropy
     policy._mean_kl = mean_kl
+    policy._value_fn_out = value_fn_out

     return total_loss

Expand All @@ -134,14 +142,14 @@ def kl_and_loss_stats(policy: Policy,
"policy_loss": policy._mean_policy_loss,
"vf_loss": policy._mean_vf_loss,
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.model.value_function()),
train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out),
"kl": policy._mean_kl,
"entropy": policy._mean_entropy,
"entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
}


# TODO: (sven) Deprecate once we only allow native keras models.
def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.

@@ -152,6 +160,10 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
         Dict[str, TensorType]: Dict with extra tf fetches to perform per
             action computation.
     """
+    # Keras models return values for each call in third return argument
+    # (dict).
+    if isinstance(policy.model, tf.keras.Model):
+        return {}
     # Return value function outputs. VF estimates will hence be added to the
     # SampleBatches produced by the sampler(s) to generate the train batches
     # going into the loss function.
@@ -256,9 +268,13 @@ def __init__(self, obs_space, action_space, config):

             @make_tf_callable(self.get_session())
             def value(**input_dict):
                 input_dict = SampleBatch(input_dict)
-                model_out, _ = self.model(input_dict)
-                # [0] = remove the batch dim.
-                return self.model.value_function()[0]
+                if isinstance(self.model, tf.keras.Model):
+                    _, _, extra_outs = self.model(input_dict)
+                    return extra_outs[SampleBatch.VF_PREDS][0]
+                else:
+                    model_out, _ = self.model(input_dict)
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]

         # TODO: (sven) Remove once trajectory view API is all-algo default.
         else:
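
For context, a hedged usage sketch of the `value()` callable defined above (hypothetical and simplified; the real call site is RLlib's GAE postprocessing, which builds the input dict from the trajectory view requirements and is not part of this diff): it bootstraps the value of the last observation of a truncated rollout, and now works for both model APIs.

    # Hypothetical, simplified bootstrapping logic:
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0  # Terminal state: nothing to bootstrap.
    else:
        # Returns a scalar VF estimate (batch dim already stripped),
        # whether self.model is a ModelV2 or a native tf.keras.Model.
        last_r = policy._value(obs=sample_batch[SampleBatch.NEXT_OBS][-1:])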
5 changes: 2 additions & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -241,7 +241,7 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:

         # Due to possible batch-repeats > 1, columns in the resulting batch
         # may not all have the same batch size.
-        batch = SampleBatch(batch_data, _dont_check_lens=True)
+        batch = SampleBatch(batch_data)

         # Add EPS_ID and UNROLL_ID to batch.
         batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count)
@@ -366,8 +366,7 @@ def build(self):
            this policy.
        """
        # Create batch from our buffers.
-        batch = SampleBatch(
-            self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
+        batch = SampleBatch(self.buffers, _seq_lens=self.seq_lens)
        # Clear buffers for future samples.
        self.buffers.clear()
        # Reset agent steps to 0 and seq-lens to empty list.
14 changes: 6 additions & 8 deletions rllib/examples/models/modelv3.py
@@ -1,5 +1,6 @@
 import numpy as np

+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch

 tf1, tf, tfv = try_import_tf()
@@ -13,11 +14,11 @@ def __init__(self,
                 input_space,
                 action_space,
                 num_outputs,
-                 model_config,
-                 name,
                 *,
+                 name="",
                 hiddens_size=256,
                 cell_size=64):
-        super(RNNModel, self).__init__()
+        super().__init__(name=name)

        self.cell_size = cell_size

@@ -46,14 +47,11 @@ def call(self, sample_batch):
        )
        lstm_out = tf.reshape(lstm_out, [-1, lstm_out.shape.as_list()[2]])
        logits = self.logits(lstm_out)
-        self._value_out = self.values(lstm_out)
-        return logits, [h, c]
+        values = tf.reshape(self.values(lstm_out), [-1])
+        return logits, [h, c], {SampleBatch.VF_PREDS: values}

    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]
-
-    def value_function(self):
-        return tf.reshape(self._value_out, [-1])

Review comments on `get_initial_state()`:

sven1977 (Contributor, Author): We should be able to get rid of this as well in the future via the trajectory view API.

A reviewer (Contributor): Some users would want different types of initial states (e.g. Xavier initialization), right?
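
The RNNModel above doubles as the PR's reference implementation of the new convention. As a smaller, self-contained sketch (an assumption on my part: it mirrors RNNModel's constructor and `call` signature, and this class is not part of the PR), a non-recurrent native keras model would look like this:

    import tensorflow as tf
    from ray.rllib.policy.sample_batch import SampleBatch

    class MLPModel(tf.keras.Model):
        # Hypothetical minimal model following the new native-keras API.
        def __init__(self, input_space, action_space, num_outputs, *,
                     name="", hiddens_size=256):
            super().__init__(name=name)
            self.dense = tf.keras.layers.Dense(hiddens_size, activation="relu")
            self.logits = tf.keras.layers.Dense(num_outputs)
            self.values = tf.keras.layers.Dense(1)

        def call(self, sample_batch):
            out = self.dense(sample_batch["obs"])
            # New API: return (model-out, state-outs, extra-outs dict);
            # no separate value_function() method is needed anymore.
            return self.logits(out), [], {
                SampleBatch.VF_PREDS: tf.reshape(self.values(out), [-1]),
            }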
64 changes: 43 additions & 21 deletions rllib/models/catalog.py
@@ -19,7 +19,8 @@
     TorchDeterministic, TorchDiagGaussian, \
     TorchMultiActionDistribution, TorchMultiCategorical
 from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \
+    deprecation_warning
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
@@ -418,26 +419,45 @@ def track_var_creation(next_creator, **kw):
                    return v

                with tf.variable_creator_scope(track_var_creation):
-                    # Try calling with kwargs first (custom ModelV2 should
-                    # accept these as kwargs, not get them from
-                    # config["custom_model_config"] anymore).
-                    try:
-                        instance = model_cls(obs_space, action_space,
-                                             num_outputs, model_config, name,
-                                             **customized_model_kwargs)
-                    except TypeError as e:
-                        # Keyword error: Try old way w/o kwargs.
-                        if "__init__() got an unexpected " in e.args[0]:
-                            instance = model_cls(obs_space, action_space,
-                                                 num_outputs, model_config,
-                                                 name, **model_kwargs)
-                            logger.warning(
-                                "Custom ModelV2 should accept all custom "
-                                "options as **kwargs, instead of expecting"
-                                " them in config['custom_model_config']!")
-                        # Other error -> re-raise.
-                        else:
-                            raise e
+                    if issubclass(model_cls, tf.keras.Model):
+                        instance = model_cls(
+                            input_space=obs_space,
+                            action_space=action_space,
+                            num_outputs=num_outputs,
+                            name=name,
+                            **customized_model_kwargs,
+                        )
+                    else:
+                        # Try calling with kwargs first (custom ModelV2 should
+                        # accept these as kwargs, not get them from
+                        # config["custom_model_config"] anymore).
+                        try:
+                            instance = model_cls(
+                                obs_space,
+                                action_space,
+                                num_outputs,
+                                model_config,
+                                name,
+                                **customized_model_kwargs,
+                            )
+                        except TypeError as e:
+                            # Keyword error: Try old way w/o kwargs.
+                            if "__init__() got an unexpected " in e.args[0]:
+                                instance = model_cls(
+                                    obs_space,
+                                    action_space,
+                                    num_outputs,
+                                    model_config,
+                                    name,
+                                    **model_kwargs,
+                                )
+                                logger.warning(
+                                    "Custom ModelV2 should accept all custom "
+                                    "options as **kwargs, instead of expecting"
+                                    " them in config['custom_model_config']!")
+                            # Other error -> re-raise.
+                            else:
+                                raise e

                # User still registered TFModelV2's variables: Check, whether
                # ok.
@@ -666,6 +686,8 @@ def register_custom_model(model_name: str, model_class: type) -> None:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
        """
+        if issubclass(model_class, tf.keras.Model):
+            deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)

    @staticmethod
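
Taken together, the two catalog changes shift the user-facing pattern: native keras classes are passed directly in the config (constructor kwargs travel via `custom_model_config`), while string registration survives for ModelV2 but now warns for keras classes. A sketch of the two styles, mirroring the test change in the next file:

    from ray.rllib.examples.models.modelv3 import RNNModel
    from ray.rllib.models import ModelCatalog

    # Old style: register under a string name (now emits a deprecation
    # warning for tf.keras.Model subclasses):
    ModelCatalog.register_custom_model("keras_model", RNNModel)
    config = {"model": {"custom_model": "keras_model"}}

    # New style: pass the class itself; constructor kwargs are forwarded
    # to __init__ via custom_model_config:
    config = {
        "model": {
            "custom_model": RNNModel,
            "custom_model_config": {"hiddens_size": 64, "cell_size": 128},
        },
    }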
8 changes: 5 additions & 3 deletions rllib/models/tests/test_models.py
@@ -5,7 +5,6 @@
 import ray
 import ray.rllib.agents.ppo as ppo
 from ray.rllib.examples.models.modelv3 import RNNModel
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
 from ray.rllib.utils.framework import try_import_tf
@@ -65,11 +64,14 @@ def test_tf_modelv2(self):
        self.assertTrue("fc_net.base_model.value_out.bias:0" in vars)

    def test_modelv3(self):
-        ModelCatalog.register_custom_model("keras_model", RNNModel)
        config = {
            "env": "CartPole-v0",
            "model": {
-                "custom_model": "keras_model",
+                "custom_model": RNNModel,
+                "custom_model_config": {
+                    "hiddens_size": 64,
+                    "cell_size": 128,
+                },
            },
            "num_workers": 0,
        }
3 changes: 2 additions & 1 deletion rllib/policy/dynamic_tf_policy.py
@@ -319,7 +319,8 @@ def __init__(
            # Pass through model. E.g., PG, PPO.
            else:
                if isinstance(self.model, tf.keras.Model):
-                    dist_inputs, self._state_out = self.model(self._input_dict)
+                    dist_inputs, self._state_out, self._extra_action_fetches = \
+                        self.model(self._input_dict)
                else:
                    dist_inputs, self._state_out = self.model(
                        self._input_dict, self._state_inputs, self._seq_lens)
2 changes: 1 addition & 1 deletion rllib/policy/policy.py
@@ -750,7 +750,7 @@ def _get_dummy_batch_from_view_requirements(

        # Due to different view requirements for the different columns,
        # columns in the resulting batch may not all have the same batch size.
-        return SampleBatch(ret, _dont_check_lens=True)
+        return SampleBatch(ret)

    def _update_model_view_requirements_from_init_state(self):
        """Uses Model's (or this Policy's) init state to add needed ViewReqs.
3 changes: 1 addition & 2 deletions rllib/policy/rnn_sequencing.py
@@ -416,8 +416,7 @@ def timeslice_along_seq_lens_with_overlap(
            i += 1
            key = "state_in_{}".format(i)

-        timeslices.append(
-            SampleBatch(data, _seq_lens=[end - begin], _dont_check_lens=True))
+        timeslices.append(SampleBatch(data, _seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
16 changes: 3 additions & 13 deletions rllib/policy/sample_batch.py
@@ -70,7 +70,6 @@ def __init__(self, *args, **kwargs):
            "seq_lens", None))
        if isinstance(self.seq_lens, list):
            self.seq_lens = np.array(self.seq_lens, dtype=np.int32)
-        self.dont_check_lens = kwargs.pop("_dont_check_lens", False)
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
        if self.max_seq_len is None and self.seq_lens is not None and \
                not (tf and tf.is_tensor(self.seq_lens)) and \
@@ -104,20 +103,12 @@ def __init__(self, *args, **kwargs):
            if isinstance(v, list):
                self[k] = np.array(v)

-        if not lengths:
-            raise ValueError("Empty sample batch")
-
-        if not self.dont_check_lens:
-            assert len(set(lengths)) == 1, \
-                "Data columns must be same length, but lens are " \
-                "{}".format(lengths)
-
        if self.seq_lens is not None and \
                not (tf and tf.is_tensor(self.seq_lens)) and \
                len(self.seq_lens) > 0:
            self.count = sum(self.seq_lens)
        else:
-            self.count = lengths[0]
+            self.count = lengths[0] if lengths else 0

    @PublicAPI
    def __len__(self):
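
A hedged sketch of the net behavioral change (assuming only the constructor semantics visible in the hunk above): columns of unequal length and empty batches are now accepted instead of raising, with `count` falling back to the first column's length, or to 0.

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # Unequal column lengths no longer trip the removed assertion
    # (needed for batch-repeats > 1; see simple_list_collector.py above).
    batch = SampleBatch({
        "a": np.array([1, 2, 3, 4]),
        "b": np.array([[0.1, 0.2]]),
    })
    print(batch.count)  # Length of the first column: 4.

    # An empty batch no longer raises ValueError("Empty sample batch").
    print(SampleBatch({}).count)  # 0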
@@ -161,7 +152,6 @@ def concat_samples(samples: List["SampleBatch"]) -> \
            out,
            _seq_lens=np.array(seq_lens, dtype=np.int32),
            _time_major=concat_samples[0].time_major,
-            _dont_check_lens=True,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
        )
@@ -211,7 +201,7 @@ def copy(self, shallow: bool = False) -> "SampleBatch":
                for (k, v) in self.items()
            },
            _seq_lens=self.seq_lens,
-            _dont_check_lens=self.dont_check_lens)
+        )
        copy_.set_get_interceptor(self.get_interceptor)
        return copy_

@@ -353,7 +343,7 @@ def slice(self, start: int, end: int) -> "SampleBatch":
                data,
                _seq_lens=np.array(seq_lens, dtype=np.int32),
                _time_major=self.time_major,
-                _dont_check_lens=True)
+            )
        else:
            return SampleBatch(
                {k: v[start:end]
4 changes: 2 additions & 2 deletions rllib/policy/tests/test_sample_batch.py
@@ -20,9 +20,9 @@ def test_dict_properties_of_sample_batches(self):
            "b": np.array([[0.1, 0.2], [0.3, 0.4]]),
            "c": True,
        }
-        batch = SampleBatch(base_dict, _dont_check_lens=True)
+        batch = SampleBatch(base_dict)
        try:
-            SampleBatch(base_dict, _dont_check_lens=False)
+            SampleBatch(base_dict)
        except AssertionError:
            pass  # expected
        keys_ = list(base_dict.keys())
11 changes: 8 additions & 3 deletions rllib/policy/tf_policy_template.py
@@ -214,11 +214,16 @@ def before_loss_init_wrapper(policy, obs_space, action_space,
                                     config):
            if before_loss_init:
                before_loss_init(policy, obs_space, action_space, config)

            if extra_action_out_fn is None:
-                policy._extra_action_fetches = {}
+                extra_action_fetches = {}
            else:
-                policy._extra_action_fetches = extra_action_out_fn(policy)
+                extra_action_fetches = extra_action_out_fn(policy)
+
+            if hasattr(policy, "_extra_action_fetches"):
+                policy._extra_action_fetches.update(extra_action_fetches)
+            else:
+                policy._extra_action_fetches = extra_action_fetches

            DynamicTFPolicy.__init__(
                self,
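
A short sketch of why the wrapper now merges instead of assigning (values hypothetical, stand-ins for tf tensors): with a native keras model, the dynamic_tf_policy.py change above has already stored the model's extra-outs dict on `policy._extra_action_fetches`, so fetches coming from `extra_action_out_fn` must be added to that dict rather than clobber it.

    # Hypothetical contents, for illustration only:
    policy_extra_action_fetches = {"vf_preds": "<tensor from keras model call>"}
    fn_extra_fetches = {"action_logp": "<tensor from extra_action_out_fn>"}

    # The merge performed by the wrapper above:
    policy_extra_action_fetches.update(fn_extra_fetches)
    assert set(policy_extra_action_fetches) == {"vf_preds", "action_logp"}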