[RLlib] Support native tf.keras.Model (milestone toward obsoleting ModelV2 class). (ray-project#14684)
sven1977 authored Apr 27, 2021
1 parent fb17ef7 commit bb8a286
Showing 24 changed files with 488 additions and 167 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1178,7 +1178,7 @@ py_test(
py_test(
name = "test_models",
tags = ["models"],
size = "small",
size = "medium",
srcs = ["models/tests/test_models.py"]
)

4 changes: 2 additions & 2 deletions rllib/agents/ddpg/tests/test_ddpg.py
@@ -414,15 +414,15 @@ def test_ddpg_loss_function(self):
trainer.stop()

def _get_batch_helper(self, obs_size, actions, batch_size):
return {
return SampleBatch({
SampleBatch.CUR_OBS: np.random.random(size=obs_size),
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: np.random.random(size=(batch_size, )),
SampleBatch.DONES: np.random.choice(
[True, False], size=(batch_size, )),
SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
"weights": np.ones(shape=(batch_size, )),
}
})

def _ddpg_loss_helper(self, train_batch, weights, ks, fw, gamma,
huber_threshold, l2_reg, sess):
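The change above, and the matching ones in the simple-Q, PG, PPO, and SAC tests below, wraps the hand-built batch dict in a SampleBatch instead of passing a raw dict into the loss function. A minimal, self-contained sketch of that wrapper with toy values (not taken from any of the tests):

```python
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

# Toy 2-row batch using the standard SampleBatch column names.
batch = SampleBatch({
    SampleBatch.CUR_OBS: np.random.random(size=(2, 4)),
    SampleBatch.ACTIONS: np.array([0, 1]),
    SampleBatch.REWARDS: np.array([0.5, -0.3]),
    SampleBatch.DONES: np.array([False, True]),
    SampleBatch.NEXT_OBS: np.random.random(size=(2, 4)),
})

# SampleBatch behaves like a dict, but also tracks its row count.
assert batch.count == 2
print(batch[SampleBatch.ACTIONS])
```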
4 changes: 2 additions & 2 deletions rllib/agents/dqn/tests/test_simple_q.py
@@ -81,7 +81,7 @@ def test_simple_q_loss_function(self):
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
# Batch of size=2.
input_ = {
input_ = SampleBatch({
SampleBatch.CUR_OBS: np.random.random(size=(2, 4)),
SampleBatch.ACTIONS: np.array([0, 1]),
SampleBatch.REWARDS: np.array([0.4, -1.23]),
@@ -94,7 +94,7 @@ def test_simple_q_loss_function(self):
[-0.1, -0.2]]),
SampleBatch.ACTION_PROB: np.array([0.1, 0.2]),
"q_values": np.array([[0.1, 0.2], [0.2, 0.1]]),
}
})
# Get model vars for computing expected model outs (q-vals).
# 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias
vars = policy.get_weights()
3 changes: 2 additions & 1 deletion rllib/agents/marwil/marwil_tf_policy.py
@@ -67,7 +67,8 @@ def postprocess_advantages(policy,
# input_dict.
# Create an input dict according to the Model's requirements.
index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
input_dict = policy.model.get_input_dict(sample_batch, index=index)
input_dict = sample_batch.get_single_step_input_dict(
policy.model.view_requirements, index=index)
last_r = policy._value(**input_dict)

# Adds the "advantages" (which in the case of MARWIL are simply the
4 changes: 2 additions & 2 deletions rllib/agents/pg/tests/test_pg.py
@@ -70,7 +70,7 @@ def test_pg_loss_functions(self):
config["model"]["fcnet_activation"] = "linear"

# Fake CartPole episode of n time steps.
train_batch = {
train_batch = SampleBatch({
SampleBatch.OBS: np.array([[0.1, 0.2, 0.3,
0.4], [0.5, 0.6, 0.7, 0.8],
[0.9, 1.0, 1.1, 1.2]]),
@@ -79,7 +79,7 @@ def test_pg_loss_functions(self):
SampleBatch.DONES: np.array([False, False, True]),
SampleBatch.EPS_ID: np.array([1234, 1234, 1234]),
SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}
})

for fw, sess in framework_iterator(config, session=True):
dist_cls = (Categorical if fw != "torch" else TorchCategorical)
38 changes: 28 additions & 10 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -29,21 +29,29 @@


def ppo_surrogate_loss(
policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
policy: Policy, model: Union[ModelV2, "tf.keras.Model"],
dist_class: Type[TFActionDistribution],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
model (Union[ModelV2, tf.keras.Model]): The Model to calculate
the loss for.
dist_class (Type[ActionDistribution]): The action distr. class.
train_batch (SampleBatch): The training data.
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
logits, state = model.from_batch(train_batch)
if isinstance(model, tf.keras.Model):
logits, state, extra_outs = model(train_batch)
value_fn_out = extra_outs[SampleBatch.VF_PREDS]
else:
logits, state = model.from_batch(train_batch)
value_fn_out = model.value_function()

curr_action_dist = dist_class(logits, model)

# RNN case: Mask away 0-padded chunks at end of time axis.
@@ -86,7 +94,6 @@ def reduce_mean_valid(t):

if policy.config["use_gae"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
vf_loss1 = tf.math.square(value_fn_out -
train_batch[Postprocessing.VALUE_TARGETS])
vf_clipped = prev_value_fn_out + tf.clip_by_value(
@@ -112,6 +119,7 @@ def reduce_mean_valid(t):
policy._mean_vf_loss = mean_vf_loss
policy._mean_entropy = mean_entropy
policy._mean_kl = mean_kl
policy._value_fn_out = value_fn_out

return total_loss

@@ -134,14 +142,14 @@ def kl_and_loss_stats(policy: Policy,
"policy_loss": policy._mean_policy_loss,
"vf_loss": policy._mean_vf_loss,
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.model.value_function()),
train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out),
"kl": policy._mean_kl,
"entropy": policy._mean_entropy,
"entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
}


# TODO: (sven) Deprecate once we only allow native keras models.
def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.
@@ -152,6 +160,10 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
Dict[str, TensorType]: Dict with extra tf fetches to perform per
action computation.
"""
# Keras models return values for each call in third return argument
# (dict).
if isinstance(policy.model, tf.keras.Model):
return {}
# Return value function outputs. VF estimates will hence be added to the
# SampleBatches produced by the sampler(s) to generate the train batches
# going into the loss function.
@@ -177,7 +189,9 @@ def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
tuples.
"""
# Compute the gradients.
variables = policy.model.trainable_variables()
variables = policy.model.trainable_variables
if isinstance(policy.model, ModelV2):
variables = variables()
grads_and_vars = optimizer.compute_gradients(loss, variables)

# Clip by global norm, if necessary.
@@ -267,9 +281,13 @@ def __init__(self, obs_space, action_space, config):
@make_tf_callable(self.get_session())
def value(**input_dict):
input_dict = SampleBatch(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0]
if isinstance(self.model, tf.keras.Model):
_, _, extra_outs = self.model(input_dict)
return extra_outs[SampleBatch.VF_PREDS][0]
else:
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0]

# When not doing GAE, we do not require the value function's output.
else:
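The `isinstance(model, tf.keras.Model)` branches above reflect the two call conventions the policy now has to support: ModelV2's `from_batch()` plus `value_function()`, versus a native Keras model whose `call()` returns `(logits, state_outs, extra_outs)` with the value predictions stored under `SampleBatch.VF_PREDS` in `extra_outs`. A minimal, hypothetical model following the native convention (not part of this commit; assumes TF 2.x is installed):

```python
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class TinyNativeKerasModel(tf.keras.Model):
    """Hypothetical policy/value net using the native-Keras convention."""

    def __init__(self, num_outputs=2):
        super().__init__()
        self.pi = tf.keras.layers.Dense(num_outputs, name="logits")
        self.vf = tf.keras.layers.Dense(1, name="values")

    def call(self, input_dict):
        obs = input_dict["obs"]
        logits = self.pi(obs)
        values = tf.reshape(self.vf(obs), [-1])
        # Value preds travel in the third return value; there is no separate
        # value_function() call as in ModelV2.
        return logits, [], {SampleBatch.VF_PREDS: values}


# In RLlib the input dict is a SampleBatch of tensors; a plain dict of
# float32 arrays is enough for this toy call.
logits, state_out, extra_outs = TinyNativeKerasModel()(
    {"obs": np.random.random((3, 4)).astype(np.float32)})
vf_preds = extra_outs[SampleBatch.VF_PREDS]  # what ppo_surrogate_loss reads
```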
8 changes: 6 additions & 2 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -20,7 +20,7 @@
check_compute_single_action

# Fake CartPole episode of n time steps.
FAKE_BATCH = {
FAKE_BATCH = SampleBatch({
SampleBatch.OBS: np.array(
[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
dtype=np.float32),
@@ -35,7 +35,7 @@
SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
SampleBatch.EPS_ID: np.array([0, 0, 0]),
SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}
})


class MyCallbacks(DefaultCallbacks):
@@ -82,6 +82,9 @@ def test_ppo_compilation_and_lr_schedule(self):
# Settings in case we use an LSTM.
config["model"]["lstm_cell_size"] = 10
config["model"]["max_seq_len"] = 20
# Use default-native keras model whenever possible.
config["model"]["_use_default_native_models"] = True

config["train_batch_size"] = 128
# Test with compression.
config["compress_observations"] = True
@@ -95,6 +98,7 @@ def test_ppo_compilation_and_lr_schedule(self):
config["model"]["use_lstm"] = lstm
config["model"]["lstm_use_prev_action"] = lstm
config["model"]["lstm_use_prev_reward"] = lstm

trainer = ppo.PPOTrainer(config=config, env=env)
for i in range(num_iterations):
trainer.train()
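For reference, `_use_default_native_models` is the experimental model-config flag this commit introduces to prefer the default native `tf.keras.Model`s over their ModelV2 counterparts where an equivalent exists. A hedged usage sketch outside the test harness (only the flag name is taken from the test above; the rest is standard PPO boilerplate):

```python
import copy

import ray
import ray.rllib.agents.ppo as ppo

ray.init(ignore_reinit_error=True)

config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["framework"] = "tf2"
config["num_workers"] = 0
# Experimental flag from this commit: use the default native Keras models
# wherever a native equivalent of the ModelV2 default exists.
config["model"]["_use_default_native_models"] = True

trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])
trainer.stop()
```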
4 changes: 2 additions & 2 deletions rllib/agents/sac/tests/test_sac.py
@@ -448,15 +448,15 @@ def test_sac_loss_function(self):
check(tf_var, torch_var, rtol=0.1)

def _get_batch_helper(self, obs_size, actions, batch_size):
return {
return SampleBatch({
SampleBatch.CUR_OBS: np.random.random(size=obs_size),
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: np.random.random(size=(batch_size, )),
SampleBatch.DONES: np.random.choice(
[True, False], size=(batch_size, )),
SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
"weights": np.random.random(size=(batch_size, )),
}
})

def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
sess):
5 changes: 2 additions & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -241,7 +241,7 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:

# Due to possible batch-repeats > 1, columns in the resulting batch
# may not all have the same batch size.
batch = SampleBatch(batch_data, _dont_check_lens=True)
batch = SampleBatch(batch_data)

# Add EPS_ID and UNROLL_ID to batch.
batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count)
@@ -366,8 +366,7 @@ def build(self):
this policy.
"""
# Create batch from our buffers.
batch = SampleBatch(
self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
batch = SampleBatch(self.buffers, _seq_lens=self.seq_lens)
# Clear buffers for future samples.
self.buffers.clear()
# Reset agent steps to 0 and seq-lens to empty list.
3 changes: 2 additions & 1 deletion rllib/evaluation/postprocessing.py
@@ -119,7 +119,8 @@ def compute_gae_for_sample_batch(
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(sample_batch, index="last")
input_dict = sample_batch.get_single_step_input_dict(
policy.model.view_requirements, index="last")
last_r = policy._value(**input_dict, seq_lens=input_dict.seq_lens)

# Adds the policy logits, VF preds, and advantages to the batch,
57 changes: 57 additions & 0 deletions rllib/examples/models/modelv3.py
@@ -0,0 +1,57 @@
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class RNNModel(tf.keras.models.Model if tf else object):
"""Example of using the Keras functional API to define an RNN model."""

def __init__(self,
input_space,
action_space,
num_outputs,
*,
name="",
hiddens_size=256,
cell_size=64):
super().__init__(name=name)

self.cell_size = cell_size

# Preprocess observation with a hidden layer and send to LSTM cell
self.dense = tf.keras.layers.Dense(
hiddens_size, activation=tf.nn.relu, name="dense1")
self.lstm = tf.keras.layers.LSTM(
cell_size, return_sequences=True, return_state=True, name="lstm")

# Postprocess LSTM output with another hidden layer and compute
# values.
self.logits = tf.keras.layers.Dense(
num_outputs, activation=tf.keras.activations.linear, name="logits")
self.values = tf.keras.layers.Dense(1, activation=None, name="values")

def call(self, sample_batch):
dense_out = self.dense(sample_batch["obs"])
B = tf.shape(sample_batch.seq_lens)[0]
lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]])
lstm_out, h, c = self.lstm(
inputs=lstm_in,
mask=tf.sequence_mask(sample_batch.seq_lens),
initial_state=[
sample_batch["state_in_0"], sample_batch["state_in_1"]
],
)
lstm_out = tf.reshape(lstm_out, [-1, lstm_out.shape.as_list()[2]])
logits = self.logits(lstm_out)
values = tf.reshape(self.values(lstm_out), [-1])
return logits, [h, c], {SampleBatch.VF_PREDS: values}

def get_initial_state(self):
return [
np.zeros(self.cell_size, np.float32),
np.zeros(self.cell_size, np.float32),
]
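A rough, hypothetical usage sketch for the example model above (shapes and constructor arguments chosen for illustration; assumes TF 2.x eager mode and the module path added by this commit):

```python
import numpy as np

from ray.rllib.examples.models.modelv3 import RNNModel
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

# The spaces are unused by this example model's constructor, so None is fine
# for a standalone call.
model = RNNModel(
    input_space=None, action_space=None, num_outputs=2, name="rnn",
    cell_size=64)

# Two sequences of length 3 each, flattened to 6 rows of 4-dim observations;
# seq_lens and the two LSTM state inputs ride along with the batch.
batch = SampleBatch(
    {
        "obs": tf.random.uniform((6, 4)),
        "state_in_0": tf.zeros((2, 64)),
        "state_in_1": tf.zeros((2, 64)),
    },
    _seq_lens=np.array([3, 3], dtype=np.int32))

logits, (h, c), extra = model(batch)
# logits: [6, num_outputs]; h, c: [2, cell_size];
# extra[SampleBatch.VF_PREDS]: [6] value estimates.
```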