
Revert "[RLlib] Issue 14533: tf.enable_eager_execution() must be ca…
Browse files Browse the repository at this point in the history
…lled at beginning. (ray-project#14737)" (ray-project#14918)

This reverts commit 3e389d5.
rkooo567 authored Mar 25, 2021
1 parent 493d15e commit fa5f961
Showing 15 changed files with 54 additions and 113 deletions.
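For context: the reverted change (ray-project#14737) was about the TF 1.x requirement that eager execution be enabled before any other TensorFlow call creates graph state. A minimal, self-contained sketch of that constraint (an illustration only, not code from this commit):

import tensorflow as tf

# Eager execution must be switched on at program startup, before any other
# TensorFlow op is created; enabling it later raises an error in TF 1.x.
tf.compat.v1.enable_eager_execution()

x = tf.constant([1.0, 2.0, 3.0])
print(x.numpy())  # tensors evaluate immediately once eager mode is active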
6 changes: 0 additions & 6 deletions rllib/agents/ddpg/ddpg.py
@@ -3,7 +3,6 @@
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

logger = logging.getLogger(__name__)

@@ -170,11 +169,6 @@ def validate_config(config):
"'complete_episodes'. Setting batch_mode=complete_episodes.")
config["batch_mode"] = "complete_episodes"

if config["simple_optimizer"] != DEPRECATED_VALUE or \
config["simple_optimizer"] is False:
logger.warning("`simple_optimizer` must be True (or unset) for DDPG!")
config["simple_optimizer"] = True


def get_policy_class(config):
if config["framework"] == "torch":
3 changes: 1 addition & 2 deletions rllib/agents/dqn/dqn.py
@@ -230,9 +230,8 @@ def update_prio(item):
# break e.g. DDPPO!).
td_error = info.get("td_error",
info[LEARNER_STATS_KEY].get("td_error"))
samples.policy_batches[policy_id].set_get_interceptor(None)
prio_dict[policy_id] = (samples.policy_batches[policy_id]
.get("batch_indexes"), td_error)
.data.get("batch_indexes"), td_error)
local_replay_buffer.update_priorities(prio_dict)
return info_dict

15 changes: 7 additions & 8 deletions rllib/agents/ppo/tests/test_ppo.py
@@ -340,9 +340,8 @@ def _ppo_loss_helper(self,
policy.model)
expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
if isinstance(model, TorchModelV2):
train_batch.set_get_interceptor(None)
expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
train_batch[SampleBatch.ACTION_LOGP])
train_batch.get(SampleBatch.ACTION_LOGP))
# KL(prev vs current action dist)-loss component.
kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
# Entropy-loss component.
@@ -365,19 +364,19 @@ def _ppo_loss_helper(self,

# Policy loss component.
pg_loss = np.minimum(
train_batch[Postprocessing.ADVANTAGES] * expected_rho,
train_batch[Postprocessing.ADVANTAGES] * np.clip(
train_batch.get(Postprocessing.ADVANTAGES) * expected_rho,
train_batch.get(Postprocessing.ADVANTAGES) * np.clip(
expected_rho, 1 - policy.config["clip_param"],
1 + policy.config["clip_param"]))

# Value function loss component.
vf_loss1 = np.power(
vf_outs - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped = train_batch[SampleBatch.VF_PREDS] + np.clip(
vf_outs - train_batch[SampleBatch.VF_PREDS],
vf_outs - train_batch.get(Postprocessing.VALUE_TARGETS), 2.0)
vf_clipped = train_batch.get(SampleBatch.VF_PREDS) + np.clip(
vf_outs - train_batch.get(SampleBatch.VF_PREDS),
-policy.config["vf_clip_param"], policy.config["vf_clip_param"])
vf_loss2 = np.power(
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped - train_batch.get(Postprocessing.VALUE_TARGETS), 2.0)
vf_loss = np.maximum(vf_loss1, vf_loss2)

# Overall loss.
5 changes: 0 additions & 5 deletions rllib/agents/sac/sac.py
@@ -178,11 +178,6 @@ def validate_config(config: TrainerConfigDict) -> None:
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")

if config["simple_optimizer"] != DEPRECATED_VALUE or \
config["simple_optimizer"] is False:
logger.warning("`simple_optimizer` must be True (or unset) for SAC!")
config["simple_optimizer"] = True


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
16 changes: 4 additions & 12 deletions rllib/agents/trainer.py
@@ -1162,26 +1162,18 @@ def _validate_config(config: PartialTrainerConfigDict):
if simple_optim_setting != DEPRECATED_VALUE:
deprecation_warning("simple_optimizer", error=False)

framework = config.get("framework")
if config.get("num_gpus", 0) > 1:
if framework in ["tfe", "tf2", "torch"]:
if config.get("framework") in ["tfe", "tf2", "torch"]:
raise ValueError("`num_gpus` > 1 not supported yet for "
"framework={}!".format(framework))
"framework={}!".format(
config.get("framework")))
elif simple_optim_setting is True:
raise ValueError(
"Cannot use `simple_optimizer` if `num_gpus` > 1! "
"Consider `simple_optimizer=False`.")
config["simple_optimizer"] = False
# Auto-setting: Use simple-optimizer for torch/tfe or multiagent,
# otherwise: TFMultiGPU (if supported by the algo's execution plan).
elif simple_optim_setting == DEPRECATED_VALUE:
config["simple_optimizer"] = \
framework != "tf" or len(config["multiagent"]["policies"]) > 0
# User manually set simple-optimizer to False -> Error if not tf.
elif simple_optim_setting is False:
if framework in ["tfe", "tf2", "torch"]:
raise ValueError("`simple_optimizer=False` not supported for "
"framework={}!".format(framework))
config["simple_optimizer"] = True

# Offline RL settings.
if isinstance(config["input_evaluation"], tuple):
1 change: 0 additions & 1 deletion rllib/evaluation/rollout_worker.py
@@ -1226,7 +1226,6 @@ def _validate_and_canonicalize(

def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
allow_none_graph: bool = False) -> None:
# Loop through all policy definitions in multi-agent policies
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("policy keys must be strs, got {}".format(
2 changes: 1 addition & 1 deletion rllib/examples/custom_keras_model.py
@@ -96,7 +96,7 @@ def metrics(self):

if __name__ == "__main__":
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None)
ray.init(num_cpus=args.num_cpus or None, local_mode=True)
ModelCatalog.register_custom_model(
"keras_model", MyVisionNetwork
if args.use_vision_network else MyKerasModel)
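Aside on the ray.init(..., local_mode=True) line in the hunk above: Ray's local mode executes tasks serially in the driver process, which makes small examples easier to debug. A hypothetical standalone usage sketch (not part of this diff):

import ray

ray.init(local_mode=True)  # run tasks serially in the driver process

@ray.remote
def square(x):
    return x * x

print(ray.get(square.remote(3)))  # 9, computed inline for easy debugging
ray.shutdown()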
7 changes: 3 additions & 4 deletions rllib/examples/custom_metrics_and_callbacks.py
@@ -104,7 +104,6 @@ def on_postprocess_trajectory(
assert "callback_ok" in trials[0].last_result

# Verify `on_learn_on_batch` custom metrics are there (per policy).
if args.torch:
info_custom_metrics = custom_metrics["default_policy"]
print(info_custom_metrics)
assert "sum_actions_in_train_batch" in info_custom_metrics
info_custom_metrics = custom_metrics["default_policy"]
print(info_custom_metrics)
assert "sum_actions_in_train_batch" in info_custom_metrics
2 changes: 1 addition & 1 deletion rllib/examples/eager_execution.py
@@ -59,7 +59,7 @@ def compute_penalty(actions, rewards):
)

if __name__ == "__main__":
ray.init()
ray.init(local_mode=True)
args = parser.parse_args()
ModelCatalog.register_custom_model("eager_model", EagerModel)

10 changes: 4 additions & 6 deletions rllib/examples/models/shared_weights_model.py
@@ -11,6 +11,10 @@
torch, nn = try_import_torch()

TF2_GLOBAL_SHARED_LAYER = None
if tf:
# The global, shared layer to be used by both models.
TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
units=64, activation=tf.nn.relu, name="fc1")


class TF2SharedWeightsModel(TFModelV2):
@@ -28,12 +32,6 @@ def __init__(self, observation_space, action_space, num_outputs,
super().__init__(observation_space, action_space, num_outputs,
model_config, name)

global TF2_GLOBAL_SHARED_LAYER
# The global, shared layer to be used by both models.
if TF2_GLOBAL_SHARED_LAYER is None:
TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
units=64, activation=tf.nn.relu, name="fc1")

inputs = tf.keras.layers.Input(observation_space.shape)
last_layer = TF2_GLOBAL_SHARED_LAYER(inputs)
output = tf.keras.layers.Dense(
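The two hunks above move creation of TF2_GLOBAL_SHARED_LAYER back to module level. A standalone sketch of the underlying pattern (illustrative names, not RLlib code): a single Keras layer instance reused by two models, so both models read and train the same weights:

import tensorflow as tf

# One layer object shared by both models; its weights are created on the first
# call and reused by every model that calls it afterwards.
SHARED_DENSE = tf.keras.layers.Dense(units=64, activation=tf.nn.relu, name="fc1")

def build_model(obs_dim, num_outputs):
    inputs = tf.keras.layers.Input(shape=(obs_dim,))
    hidden = SHARED_DENSE(inputs)  # same layer instance in each model
    outputs = tf.keras.layers.Dense(num_outputs)(hidden)
    return tf.keras.Model(inputs, outputs)

model_a = build_model(4, 2)
model_b = build_model(4, 2)
assert model_a.get_layer("fc1") is model_b.get_layer("fc1")  # truly shared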
13 changes: 5 additions & 8 deletions rllib/policy/dynamic_tf_policy.py
@@ -494,27 +494,24 @@ def _initialize_loss_from_dummy_batch(
dict(self._input_dict, **self._loss_input_dict))

if self._state_inputs:
train_batch.seq_lens = self._seq_lens
self._loss_input_dict.update({"seq_lens": train_batch.seq_lens})

self._loss_input_dict.update({k: v for k, v in train_batch.items()})
train_batch["seq_lens"] = self._seq_lens

if log_once("loss_init"):
logger.debug(
"Initializing loss function with dummy input:\n\n{}\n".format(
summarize(train_batch)))

self._loss_input_dict.update({k: v for k, v in train_batch.items()})
loss = self._do_loss_init(train_batch)

all_accessed_keys = \
train_batch.accessed_keys | dummy_batch.accessed_keys | \
dummy_batch.added_keys | set(
self.model.view_requirements.keys())

TFPolicy._initialize_loss(self, loss, [
(k, v) for k, v in train_batch.items() if k in all_accessed_keys
] + ([("seq_lens", train_batch.seq_lens)]
if train_batch.seq_lens is not None else []))
TFPolicy._initialize_loss(self, loss, [(k, v)
for k, v in train_batch.items()
if k in all_accessed_keys])

if "is_training" in self._loss_input_dict:
del self._loss_input_dict["is_training"]
6 changes: 2 additions & 4 deletions rllib/policy/eager_tf_policy.py
@@ -26,10 +26,8 @@

def _convert_to_tf(x, dtype=None):
if isinstance(x, SampleBatch):
dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
if x.seq_lens is not None:
dict_["seq_lens"] = x.seq_lens
return tf.nest.map_structure(_convert_to_tf, dict_)
x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
return tf.nest.map_structure(_convert_to_tf, x)
elif isinstance(x, Policy):
return x
# Special handling of "Repeated" values.
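The _convert_to_tf hunk above leans on tf.nest.map_structure to convert every leaf of a (possibly nested) batch dict. A small standalone sketch with made-up values:

import numpy as np
import tensorflow as tf

batch = {"obs": np.zeros((2, 4)), "rewards": np.array([0.5, 1.5])}
# Apply tf.convert_to_tensor to each leaf while preserving the dict structure.
tensors = tf.nest.map_structure(tf.convert_to_tensor, batch)
print(tensors["obs"].shape)  # (2, 4), now a tf.Tensor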
37 changes: 9 additions & 28 deletions rllib/policy/sample_batch.py
@@ -63,13 +63,9 @@ def __init__(self, *args, **kwargs):

# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
self.seq_lens = kwargs.pop("_seq_lens", kwargs.pop("seq_lens", None))
if self.seq_lens is None and len(args) > 0 and isinstance(
args[0], dict):
self.seq_lens = args[0].pop("_seq_lens", args[0].pop(
"seq_lens", None))
self.seq_lens = kwargs.pop("_seq_lens", None)
if isinstance(self.seq_lens, list):
self.seq_lens = np.array(self.seq_lens, dtype=np.int32)
self.seq_lens = np.array(self.seq_lens)
self.dont_check_lens = kwargs.pop("_dont_check_lens", False)
self.max_seq_len = kwargs.pop("_max_seq_len", None)
if self.max_seq_len is None and self.seq_lens is not None and \
@@ -83,16 +79,18 @@ def __init__(self, *args, **kwargs):
# by column name (str) via e.g. self["some-col"].
dict.__init__(self, *args, **kwargs)

if self.is_training is None:
self.is_training = self.pop("is_training", False)
if self.seq_lens is None:
self.seq_lens = self.get("seq_lens", None)

self.accessed_keys = set()
self.added_keys = set()
self.deleted_keys = set()
self.intercepted_values = {}

self.get_interceptor = None

if self.is_training is None:
self.is_training = self.pop("is_training", False)

lengths = []
copy_ = {k: v for k, v in self.items()}
for k, v in copy_.items():
@@ -430,12 +428,6 @@ def size_bytes(self) -> int:
"""
return sum(sys.getsizeof(d) for d in self.values())

def get(self, key, default=None):
try:
return self.__getitem__(key)
except KeyError:
return default

@PublicAPI
def __getitem__(self, key: str) -> TensorType:
"""Returns one column (by key) from the data.
@@ -446,8 +438,6 @@ def __getitem__(self, key: str) -> TensorType:
Returns:
TensorType: The data under the given key.
"""
self.accessed_keys.add(key)

# Backward compatibility for when "input-dicts" were used.
if key == "is_training":
if log_once("SampleBatch['is_training']"):
Expand All @@ -456,14 +446,8 @@ def __getitem__(self, key: str) -> TensorType:
new="SampleBatch.is_training",
error=False)
return self.is_training
elif key == "seq_lens":
if self.get_interceptor is not None and self.seq_lens is not None:
if "seq_lens" not in self.intercepted_values:
self.intercepted_values["seq_lens"] = self.get_interceptor(
self.seq_lens)
return self.intercepted_values["seq_lens"]
return self.seq_lens

self.accessed_keys.add(key)
value = dict.__getitem__(self, key)
if self.get_interceptor is not None:
if key not in self.intercepted_values:
@@ -479,12 +463,9 @@ def __setitem__(self, key, item) -> None:
key (str): The column name to set a value for.
item (TensorType): The data to insert.
"""
if key == "seq_lens":
self.seq_lens = item
return
# Defend against creating SampleBatch via pickle (no property
# `added_keys` and first item is already set).
elif not hasattr(self, "added_keys"):
if not hasattr(self, "added_keys"):
dict.__setitem__(self, key, item)
return

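The SampleBatch hunks above drop the get() override and the seq_lens special-casing, so get() falls back to the plain dict implementation. A self-contained toy sketch (not RLlib code) of the behavioral difference that override existed for: dict.get() does not route through __getitem__, so per-key interception only applies to bracket access:

class InterceptingBatch(dict):
    """Toy stand-in for a batch whose __getitem__ post-processes values."""

    def __getitem__(self, key):
        value = dict.__getitem__(self, key)
        return ("intercepted", value)

b = InterceptingBatch(obs=1)
print(b["obs"])      # ('intercepted', 1) -- routed through __getitem__
print(b.get("obs"))  # 1 -- inherited dict.get() bypasses __getitem__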
25 changes: 6 additions & 19 deletions rllib/tests/test_export.py
@@ -110,43 +110,30 @@ def valid_tf_checkpoint(checkpoint_dir):
assert model

shutil.rmtree(export_dir)
algo.stop()


class TestExport(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
ray.init(
num_cpus=10, object_store_memory=1e9, ignore_reinit_error=True)

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_export_a3c(self):
failures = []
export_test("A3C", failures, "tf")
assert not failures, failures

def test_export_ddpg(self):
failures = []
export_test("DDPG", failures, "tf")
assert not failures, failures

def test_export_dqn(self):
failures = []
export_test("DQN", failures, "tf")
assert not failures, failures

def test_export_ppo(self):
failures = []
export_test("PPO", failures, "torch")
export_test("PPO", failures, "tf")
assert not failures, failures

def test_export_sac(self):
def test_export(self):
failures = []
export_test("SAC", failures, "tf")
for name in ["A3C", "DQN", "DDPG", "SAC"]:
export_test(name, failures)
assert not failures, failures
print("All export tests passed!")


if __name__ == "__main__":