[RLlib] Redo issue 14533 tf enable eager exec #14984

Merged

Changes from 1 commit
wip.
sven1977 committed Mar 29, 2021
commit 8ba7bbdb8c9ba32b1e2c065146aabdcc40532b02
6 changes: 6 additions & 0 deletions rllib/agents/ddpg/ddpg.py
@@ -3,6 +3,7 @@
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

logger = logging.getLogger(__name__)

@@ -169,6 +170,11 @@ def validate_config(config):
"'complete_episodes'. Setting batch_mode=complete_episodes.")
config["batch_mode"] = "complete_episodes"

if config["simple_optimizer"] != DEPRECATED_VALUE or \
config["simple_optimizer"] is False:
logger.warning("`simple_optimizer` must be True (or unset) for DDPG!")
config["simple_optimizer"] = True


def get_policy_class(config):
if config["framework"] == "torch":
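For context, the new DDPG check treats `simple_optimizer` as an auto-managed setting: if the user touched it at all, RLlib warns and pins it back to True. A minimal standalone sketch of the sentinel pattern (the `force_simple_optimizer` helper below is hypothetical, not code from this PR; `DEPRECATED_VALUE` mirrors RLlib's -1 sentinel):

import logging

logger = logging.getLogger(__name__)

# Sentinel meaning "the user never set this key" (RLlib uses -1).
DEPRECATED_VALUE = -1


def force_simple_optimizer(config: dict) -> dict:
    # Warn if the user set `simple_optimizer` explicitly, then pin it to True,
    # which is what the DDPG/SAC validate_config changes in this PR enforce.
    if config.get("simple_optimizer", DEPRECATED_VALUE) != DEPRECATED_VALUE:
        logger.warning("`simple_optimizer` must be True (or unset) for DDPG!")
    config["simple_optimizer"] = True
    return config


# Example: a user-set value gets overridden (with a warning).
print(force_simple_optimizer({"simple_optimizer": False}))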
3 changes: 2 additions & 1 deletion rllib/agents/dqn/dqn.py
@@ -230,8 +230,9 @@ def update_prio(item):
# break e.g. DDPPO!).
td_error = info.get("td_error",
info[LEARNER_STATS_KEY].get("td_error"))
samples.policy_batches[policy_id].set_get_interceptor(None)
prio_dict[policy_id] = (samples.policy_batches[policy_id]
.data.get("batch_indexes"), td_error)
.get("batch_indexes"), td_error)
local_replay_buffer.update_priorities(prio_dict)
return info_dict

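The added `set_get_interceptor(None)` call makes sure `batch_indexes` is read back as raw data rather than through a framework-tensor interceptor before priorities are updated. A toy sketch of the interceptor idea on a plain dict subclass (hypothetical class, not RLlib's SampleBatch):

import numpy as np


class InterceptedBatch(dict):
    # Toy stand-in for a batch whose __getitem__ can post-process values.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.get_interceptor = None

    def set_get_interceptor(self, fn):
        self.get_interceptor = fn

    def __getitem__(self, key):
        value = dict.__getitem__(self, key)
        # If an interceptor is set (e.g. "convert to tf tensor"), apply it.
        return self.get_interceptor(value) if self.get_interceptor else value


batch = InterceptedBatch(batch_indexes=np.arange(4))
batch.set_get_interceptor(lambda v: v.tolist())  # pretend "to tensor"
print(batch["batch_indexes"])                    # intercepted: [0, 1, 2, 3]
batch.set_get_interceptor(None)                  # back to the raw numpy array
print(batch["batch_indexes"])                    # array([0, 1, 2, 3])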
14 changes: 7 additions & 7 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -341,7 +341,7 @@ def _ppo_loss_helper(self,
expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
if isinstance(model, TorchModelV2):
expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
train_batch.get(SampleBatch.ACTION_LOGP))
train_batch[SampleBatch.ACTION_LOGP])
# KL(prev vs current action dist)-loss component.
kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
# Entropy-loss component.
@@ -364,19 +364,19 @@ def _ppo_loss_helper(self,

# Policy loss component.
pg_loss = np.minimum(
train_batch.get(Postprocessing.ADVANTAGES) * expected_rho,
train_batch.get(Postprocessing.ADVANTAGES) * np.clip(
train_batch[Postprocessing.ADVANTAGES] * expected_rho,
train_batch[Postprocessing.ADVANTAGES] * np.clip(
expected_rho, 1 - policy.config["clip_param"],
1 + policy.config["clip_param"]))

# Value function loss component.
vf_loss1 = np.power(
vf_outs - train_batch.get(Postprocessing.VALUE_TARGETS), 2.0)
vf_clipped = train_batch.get(SampleBatch.VF_PREDS) + np.clip(
vf_outs - train_batch.get(SampleBatch.VF_PREDS),
vf_outs - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped = train_batch[SampleBatch.VF_PREDS] + np.clip(
vf_outs - train_batch[SampleBatch.VF_PREDS],
-policy.config["vf_clip_param"], policy.config["vf_clip_param"])
vf_loss2 = np.power(
vf_clipped - train_batch.get(Postprocessing.VALUE_TARGETS), 2.0)
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss = np.maximum(vf_loss1, vf_loss2)

# Overall loss.
5 changes: 5 additions & 0 deletions rllib/agents/sac/sac.py
@@ -178,6 +178,11 @@ def validate_config(config: TrainerConfigDict) -> None:
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")

if config["simple_optimizer"] != DEPRECATED_VALUE or \
config["simple_optimizer"] is False:
logger.warning("`simple_optimizer` must be True (or unset) for SAC!")
config["simple_optimizer"] = True


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
17 changes: 13 additions & 4 deletions rllib/agents/trainer.py
@@ -1162,18 +1162,27 @@ def _validate_config(config: PartialTrainerConfigDict):
if simple_optim_setting != DEPRECATED_VALUE:
deprecation_warning("simple_optimizer", error=False)

framework = config.get("framework")
if config.get("num_gpus", 0) > 1:
if config.get("framework") in ["tfe", "tf2", "torch"]:
if framework in ["tfe", "tf2", "torch"]:
raise ValueError("`num_gpus` > 1 not supported yet for "
"framework={}!".format(
config.get("framework")))
"framework={}!".format(framework))
elif simple_optim_setting is True:
raise ValueError(
"Cannot use `simple_optimizer` if `num_gpus` > 1! "
"Consider `simple_optimizer=False`.")
config["simple_optimizer"] = False
# Auto-setting: Use simple-optimizer for torch/tfe or multiagent,
# otherwise: TFMultiGPU (if supported by the algo's execution plan).
elif simple_optim_setting == DEPRECATED_VALUE:
config["simple_optimizer"] = True
config["simple_optimizer"] = \
framework != "tf" or len(config["multiagent"]["policies"]) > 0
# User manually set simple-optimizer to False -> Error if not tf.
elif simple_optim_setting is False:
if framework in ["tfe", "tf2", "torch"]:
raise ValueError("`simple_optimizer=False` not supported for "
"framework={}!".format(framework))


# Offline RL settings.
if isinstance(config["input_evaluation"], tuple):
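In other words, `simple_optimizer` now resolves roughly as follows. This is a condensed sketch of the decision logic only; `resolve_simple_optimizer` is a made-up name and the real code operates on the full config dict:

DEPRECATED_VALUE = -1  # sentinel: the user never set the key


def resolve_simple_optimizer(setting, framework, num_gpus, num_policies):
    # Sketch of the auto-resolution added to Trainer._validate_config().
    if num_gpus > 1:
        if framework in ["tfe", "tf2", "torch"]:
            raise ValueError(
                "`num_gpus` > 1 not supported yet for framework={}!".format(framework))
        if setting is True:
            raise ValueError("Cannot use `simple_optimizer` if `num_gpus` > 1!")
        return False  # multi-GPU tf -> TFMultiGPU execution path
    if setting == DEPRECATED_VALUE:
        # Auto-set: simple optimizer for torch/tfe/tf2 or multiagent setups,
        # otherwise TFMultiGPU (if the algo's execution plan supports it).
        return framework != "tf" or num_policies > 0
    if setting is False and framework in ["tfe", "tf2", "torch"]:
        raise ValueError(
            "`simple_optimizer=False` not supported for framework={}!".format(framework))
    return setting


print(resolve_simple_optimizer(DEPRECATED_VALUE, "tf", 0, 0))     # False -> TFMultiGPU
print(resolve_simple_optimizer(DEPRECATED_VALUE, "torch", 0, 0))  # True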
1 change: 1 addition & 0 deletions rllib/evaluation/rollout_worker.py
@@ -1226,6 +1226,7 @@ def _validate_and_canonicalize(

def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
allow_none_graph: bool = False) -> None:
    # Loop through all policy definitions in the multi-agent policy dict.
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("policy keys must be strs, got {}".format(
2 changes: 1 addition & 1 deletion rllib/examples/custom_keras_model.py
@@ -96,7 +96,7 @@ def metrics(self):

if __name__ == "__main__":
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None, local_mode=True)
ray.init(num_cpus=args.num_cpus or None)
ModelCatalog.register_custom_model(
"keras_model", MyVisionNetwork
if args.use_vision_network else MyKerasModel)
7 changes: 4 additions & 3 deletions rllib/examples/custom_metrics_and_callbacks.py
@@ -104,6 +104,7 @@ def on_postprocess_trajectory(
assert "callback_ok" in trials[0].last_result

# Verify `on_learn_on_batch` custom metrics are there (per policy).
info_custom_metrics = custom_metrics["default_policy"]
print(info_custom_metrics)
assert "sum_actions_in_train_batch" in info_custom_metrics
if args.torch:
info_custom_metrics = custom_metrics["default_policy"]
print(info_custom_metrics)
assert "sum_actions_in_train_batch" in info_custom_metrics
2 changes: 1 addition & 1 deletion rllib/examples/eager_execution.py
@@ -59,7 +59,7 @@ def compute_penalty(actions, rewards):
)

if __name__ == "__main__":
ray.init(local_mode=True)
ray.init()
args = parser.parse_args()
ModelCatalog.register_custom_model("eager_model", EagerModel)

10 changes: 6 additions & 4 deletions rllib/examples/models/shared_weights_model.py
@@ -11,10 +11,6 @@
torch, nn = try_import_torch()

TF2_GLOBAL_SHARED_LAYER = None
if tf:
# The global, shared layer to be used by both models.
TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
units=64, activation=tf.nn.relu, name="fc1")


class TF2SharedWeightsModel(TFModelV2):
@@ -32,6 +28,12 @@ def __init__(self, observation_space, action_space, num_outputs,
super().__init__(observation_space, action_space, num_outputs,
model_config, name)

global TF2_GLOBAL_SHARED_LAYER
# The global, shared layer to be used by both models.
if TF2_GLOBAL_SHARED_LAYER is None:
TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
units=64, activation=tf.nn.relu, name="fc1")

inputs = tf.keras.layers.Input(observation_space.shape)
last_layer = TF2_GLOBAL_SHARED_LAYER(inputs)
output = tf.keras.layers.Dense(
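The shared layer is now created lazily inside the first model constructor instead of at import time, so it is only built once a TF context actually exists. A minimal sketch of that lazy-global pattern (assuming TensorFlow 2.x is installed; not the full RLlib model):

import tensorflow as tf

TF2_GLOBAL_SHARED_LAYER = None  # created on first use, not at import time


def get_shared_layer():
    # Return the single shared Dense layer, creating it on the first call.
    global TF2_GLOBAL_SHARED_LAYER
    if TF2_GLOBAL_SHARED_LAYER is None:
        TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
            units=64, activation=tf.nn.relu, name="fc1")
    return TF2_GLOBAL_SHARED_LAYER


# Both models end up holding the exact same layer object (and thus weights).
assert get_shared_layer() is get_shared_layer()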
13 changes: 8 additions & 5 deletions rllib/policy/dynamic_tf_policy.py
@@ -494,24 +494,27 @@ def _initialize_loss_from_dummy_batch(
dict(self._input_dict, **self._loss_input_dict))

if self._state_inputs:
train_batch["seq_lens"] = self._seq_lens
train_batch.seq_lens = self._seq_lens
self._loss_input_dict.update({"seq_lens": train_batch.seq_lens})

self._loss_input_dict.update({k: v for k, v in train_batch.items()})

if log_once("loss_init"):
logger.debug(
"Initializing loss function with dummy input:\n\n{}\n".format(
summarize(train_batch)))

self._loss_input_dict.update({k: v for k, v in train_batch.items()})
loss = self._do_loss_init(train_batch)

all_accessed_keys = \
train_batch.accessed_keys | dummy_batch.accessed_keys | \
dummy_batch.added_keys | set(
self.model.view_requirements.keys())

TFPolicy._initialize_loss(self, loss, [(k, v)
for k, v in train_batch.items()
if k in all_accessed_keys])
TFPolicy._initialize_loss(self, loss, [
(k, v) for k, v in train_batch.items() if k in all_accessed_keys
] + ([("seq_lens", train_batch.seq_lens)]
if train_batch.seq_lens is not None else []))

if "is_training" in self._loss_input_dict:
del self._loss_input_dict["is_training"]
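Condensed, the loss inputs are now assembled from the columns the loss actually accessed, with `seq_lens` appended separately because it lives as an attribute on the batch rather than as a dict column. A small sketch with stand-in types (`_Batch` and `build_loss_inputs` are hypothetical):

class _Batch(dict):
    seq_lens = None


def build_loss_inputs(train_batch, all_accessed_keys):
    # Keep only the accessed columns, then append seq_lens if present.
    loss_inputs = [(k, v) for k, v in train_batch.items()
                   if k in all_accessed_keys]
    if train_batch.seq_lens is not None:
        loss_inputs.append(("seq_lens", train_batch.seq_lens))
    return loss_inputs


b = _Batch(obs=[1, 2], actions=[0, 1], unused=[9])
b.seq_lens = [2]
print(build_loss_inputs(b, {"obs", "actions"}))
# [('obs', [1, 2]), ('actions', [0, 1]), ('seq_lens', [2])]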
6 changes: 4 additions & 2 deletions rllib/policy/eager_tf_policy.py
@@ -26,8 +26,10 @@

def _convert_to_tf(x, dtype=None):
if isinstance(x, SampleBatch):
x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
return tf.nest.map_structure(_convert_to_tf, x)
dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
if x.seq_lens is not None:
dict_["seq_lens"] = x.seq_lens
return tf.nest.map_structure(_convert_to_tf, dict_)
elif isinstance(x, Policy):
return x
# Special handling of "Repeated" values.
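The eager conversion now also carries `seq_lens` into the tensor dict while still skipping `infos`, which holds arbitrary Python objects. A standalone approximation on plain dicts (assuming TensorFlow 2.x; `INFOS_KEY` stands in for `SampleBatch.INFOS`):

import numpy as np
import tensorflow as tf

INFOS_KEY = "infos"  # stand-in for SampleBatch.INFOS


def convert_batch_to_tf(batch, seq_lens=None):
    # Convert every column except `infos` to tf tensors and add `seq_lens`
    # as an extra column when it is given.
    dict_ = {k: v for k, v in batch.items() if k != INFOS_KEY}
    if seq_lens is not None:
        dict_["seq_lens"] = seq_lens
    return tf.nest.map_structure(tf.convert_to_tensor, dict_)


out = convert_batch_to_tf(
    {"obs": np.zeros((4, 2)), "infos": [{}] * 4},
    seq_lens=np.array([2, 2], dtype=np.int32))
print(sorted(out.keys()))  # ['obs', 'seq_lens']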
39 changes: 30 additions & 9 deletions rllib/policy/sample_batch.py
@@ -63,9 +63,15 @@ def __init__(self, *args, **kwargs):

# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
self.seq_lens = kwargs.pop("_seq_lens", None)

self.seq_lens = kwargs.pop("_seq_lens", kwargs.pop("seq_lens", None))
if self.seq_lens is None and len(args) > 0 and isinstance(
args[0], dict):
self.seq_lens = args[0].pop("_seq_lens", args[0].pop(
"seq_lens", None))
if isinstance(self.seq_lens, list):
self.seq_lens = np.array(self.seq_lens)
self.seq_lens = np.array(self.seq_lens, dtype=np.int32)

self.dont_check_lens = kwargs.pop("_dont_check_lens", False)
self.max_seq_len = kwargs.pop("_max_seq_len", None)
if self.max_seq_len is None and self.seq_lens is not None and \
@@ -79,18 +85,16 @@ def __init__(self, *args, **kwargs):
# by column name (str) via e.g. self["some-col"].
dict.__init__(self, *args, **kwargs)

if self.is_training is None:
self.is_training = self.pop("is_training", False)
if self.seq_lens is None:
self.seq_lens = self.get("seq_lens", None)

self.accessed_keys = set()
self.added_keys = set()
self.deleted_keys = set()
self.intercepted_values = {}

self.get_interceptor = None

if self.is_training is None:
self.is_training = self.pop("is_training", False)

lengths = []
copy_ = {k: v for k, v in self.items()}
for k, v in copy_.items():
@@ -428,6 +432,12 @@ def size_bytes(self) -> int:
"""
return sum(sys.getsizeof(d) for d in self.values())

def get(self, key, default=None):
try:
return self.__getitem__(key)
except KeyError:
return default

@PublicAPI
def __getitem__(self, key: str) -> TensorType:
"""Returns one column (by key) from the data.
@@ -438,6 +448,8 @@ def __getitem__(self, key: str) -> TensorType:
Returns:
TensorType: The data under the given key.
"""
self.accessed_keys.add(key)

# Backward compatibility for when "input-dicts" were used.
if key == "is_training":
if log_once("SampleBatch['is_training']"):
@@ -446,8 +458,14 @@
new="SampleBatch.is_training",
error=False)
return self.is_training
elif key == "seq_lens":
if self.get_interceptor is not None and self.seq_lens is not None:
if "seq_lens" not in self.intercepted_values:
self.intercepted_values["seq_lens"] = self.get_interceptor(
self.seq_lens)
return self.intercepted_values["seq_lens"]
return self.seq_lens

self.accessed_keys.add(key)
value = dict.__getitem__(self, key)
if self.get_interceptor is not None:
if key not in self.intercepted_values:
@@ -463,9 +481,12 @@ def __setitem__(self, key, item) -> None:
key (str): The column name to set a value for.
item (TensorType): The data to insert.
"""
if key == "seq_lens":
self.seq_lens = item
return
# Defend against creating SampleBatch via pickle (no property
# `added_keys` and first item is already set).
if not hasattr(self, "added_keys"):
elif not hasattr(self, "added_keys"):
dict.__setitem__(self, key, item)
return

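The new `SampleBatch.get()` override matters because plain `dict.get()` would bypass both the key-access tracking and the get-interceptor that `__getitem__` implements. A toy illustration of that pitfall and the fix (not the real SampleBatch):

class TrackedDict(dict):
    # Toy dict that records which keys were read via __getitem__.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return dict.__getitem__(self, key)

    def get(self, key, default=None):
        # Route .get() through __getitem__ so tracking (and, in RLlib, the
        # get-interceptor) also applies to .get() lookups.
        try:
            return self[key]
        except KeyError:
            return default


d = TrackedDict(a=1)
d.get("a")
d.get("missing", 0)
print(d.accessed_keys)  # {'a', 'missing'} -- even failed lookups are recorded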
24 changes: 19 additions & 5 deletions rllib/tests/test_export.py
@@ -110,28 +110,42 @@ def valid_tf_checkpoint(checkpoint_dir):
assert model

shutil.rmtree(export_dir)
algo.stop()


class TestExport(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(
num_cpus=10, object_store_memory=1e9, ignore_reinit_error=True)
ray.init(num_cpus=4)

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_export_a3c(self):
failures = []
export_test("A3C", failures, "tf")
assert not failures, failures

def test_export_ddpg(self):
failures = []
export_test("DDPG", failures, "tf")
assert not failures, failures

def test_export_dqn(self):
failures = []
export_test("DQN", failures, "tf")
assert not failures, failures

def test_export_ppo(self):
failures = []
export_test("PPO", failures, "torch")
export_test("PPO", failures, "tf")
assert not failures, failures

def test_export(self):
def test_export_sac(self):
failures = []
for name in ["A3C", "DQN", "DDPG", "SAC"]:
export_test(name, failures)
export_test("SAC", failures, "tf")
assert not failures, failures
print("All export tests passed!")

17 changes: 8 additions & 9 deletions rllib/tests/test_lstm.py
@@ -121,6 +121,7 @@ def test_simple_optimizer_sequencing(self):
"train_batch_size": 10,
"sgd_minibatch_size": 10,
"num_sgd_iter": 1,
"simple_optimizer": True,
"model": {
"custom_model": "rnn",
"max_seq_len": 4,
@@ -196,8 +197,8 @@ def test_minibatch_sequencing(self):
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
if batch0["sequences"][0][0][0] > batch1["sequences"][0][0][0]:
batch0, batch1 = batch1, batch0 # sort minibatches
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4, 2])
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3, 3])
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4])
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3])
self.assertEqual(batch0["sequences"].tolist(), [
[[0], [1], [2], [3]],
[[4], [5], [6], [7]],
@@ -217,17 +218,15 @@
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_3"))
if batch2["sequences"][0][0][0] > batch3["sequences"][0][0][0]:
batch2, batch3 = batch3, batch2
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4, 2])
self.assertEqual(batch3["seq_lens"].tolist(), [4, 4, 2])
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4])
self.assertEqual(batch3["seq_lens"].tolist(), [2, 4])
self.assertEqual(batch2["sequences"].tolist(), [
[[0], [1], [2], [3]],
[[4], [5], [6], [7]],
[[8], [9], [0], [0]],
])
self.assertEqual(batch3["sequences"].tolist(), [
[[5], [6], [7], [8]],
[[9], [10], [11], [12]],
])
self.assertEqual(batch3["sequences"].tolist(), [
[[13], [14], [0], [0]],
[[0], [1], [2], [3]],
])

