use policy.as_eager() to convert to eager template
ericl committed Aug 12, 2019
1 parent 38ea9a9 commit 87d1999
Showing 6 changed files with 42 additions and 44 deletions.
26 changes: 0 additions & 26 deletions rllib/agents/pg/eager_pg_policy.py

This file was deleted.

7 changes: 0 additions & 7 deletions rllib/agents/pg/pg.py
@@ -5,7 +5,6 @@
 from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.agents.pg.pg_policy import PGTFPolicy
-from ray.rllib.agents.pg.eager_pg_policy import PGTFPolicy as EagerPGTFPolicy
 from ray.rllib.utils import try_import_tf

 # yapf: disable
@@ -17,8 +16,6 @@
     "lr": 0.0004,
     # Use PyTorch as backend
     "use_pytorch": False,
-    # Use TF eager:
-    "use_eager": False,
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -32,10 +29,6 @@ def get_policy_class(config):
     if config["use_pytorch"]:
         from ray.rllib.agents.pg.torch_pg_policy import PGTorchPolicy
         return PGTorchPolicy
-    elif config["use_eager"]:
-        tf = try_import_tf()
-        tf.enable_eager_execution()
-        return EagerPGTFPolicy
     else:
         return PGTFPolicy

2 changes: 2 additions & 0 deletions rllib/agents/trainer.py
@@ -70,6 +70,8 @@
     "ignore_worker_failures": False,
     # Log system resource metrics to results.
     "log_sys_usage": True,
+    # Run policy using TF eager if possible
+    "use_eager": False,

     # === Policy ===
     # Arguments to pass to model. See models/catalog.py for a full list of the
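
Since the flag now lives in the common trainer config, any algorithm can opt in per run. A minimal usage sketch, assuming the usual PGTrainer entry point built in pg.py above (the environment name is illustrative):

    import ray
    from ray.rllib.agents.pg import PGTrainer

    ray.init()
    # "use_eager" is the common config key added above; when True, the trainer
    # template converts the policy class with .as_eager().
    trainer = PGTrainer(
        env="CartPole-v0",
        config={"use_eager": True})
    print(trainer.train())
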
2 changes: 2 additions & 0 deletions rllib/agents/trainer_template.py
@@ -97,6 +97,8 @@ def _init(self, config, env_creator):
                 policy = default_policy
             else:
                 policy = get_policy_class(config)
+            if config["use_eager"]:
+                policy = policy.as_eager()
             if before_init:
                 before_init(self)
             if make_workers:
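
Because the check sits in the shared template, any trainer built with build_trainer inherits it without algorithm-specific eager code. A hedged sketch, assuming build_trainer's usual name/default_config/default_policy keywords and reusing the PG policy and config purely for illustration:

    from ray.rllib.agents.pg.pg import DEFAULT_CONFIG
    from ray.rllib.agents.pg.pg_policy import PGTFPolicy
    from ray.rllib.agents.trainer_template import build_trainer

    # _init() above calls policy.as_eager() whenever this trainer is created
    # with {"use_eager": True} in its config.
    MyTrainer = build_trainer(
        name="MyPGVariant",
        default_config=DEFAULT_CONFIG,
        default_policy=PGTFPolicy)
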
39 changes: 28 additions & 11 deletions rllib/policy/eager_policy.py
@@ -152,14 +152,20 @@ def export_checkpoint(self, export_dir):
         return NotImplemented


-def build_tf_policy(name, postprocess_fn, loss_fn,
-                    make_optimizer, get_default_config=None, stats_fn=None):
-    catalog.ModelCatalog.register_custom_model("keras_model",
-                                               fcnet_v2.FullyConnectedNetwork)
+def build_tf_policy(name, loss_fn,
+                    get_default_config=None,
+                    postprocess_fn=None,
+                    stats_fn=None,
+                    optimizer_fn=None,
+                    obs_include_prev_action_reward=True):

     class EagerPolicy(TFEagerPolicy):
         def __init__(self, observation_space, action_space, config):
+            assert tf.executing_eagerly()
+
             if get_default_config:
                 config = dict(get_default_config(), **config)
+            self.config = config
+
             self.dist_class, logit_dim = catalog.ModelCatalog.get_action_dist(
                 action_space, config["model"])
@@ -171,14 +177,19 @@ def __init__(self, observation_space, action_space, config):
                 config["model"],
                 framework="tf",
             )
-            optimizer = make_optimizer(self, observation_space, action_space,
-                                       config)
-            self.config = config
+
+            if optimizer_fn:
+                optimizer = optimizer_fn(self, observation_space, action_space,
+                                         config)
+            else:
+                optimizer = tf.train.AdamOptimizer(config["lr"])
+
             TFEagerPolicy.__init__(self, optimizer, model, observation_space,
                                    action_space)

         def postprocess_trajectory(self, samples, other_agent_batches=None,
                                    episode=None):
+            assert tf.executing_eagerly()
             return postprocess_fn(self, samples)

         def compute_actions(self,
@@ -189,25 +200,31 @@ def compute_actions(self,
                             info_batch=None,
                             episodes=None,
                             **kwargs):
+            assert tf.executing_eagerly()
             seq_len = tf.ones(len(obs_batch))
             input_dict = {
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(
                     obs_batch, dtype=tf.float32),
-                SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
-                    prev_action_batch, dtype=tf.int32),
-                SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
-                    prev_reward_batch, dtype=tf.float32),
                 "is_training": tf.convert_to_tensor(True),
             }
+            if obs_include_prev_action_reward:
+                input_dict.update({
+                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
+                        prev_action_batch, dtype=tf.int32),
+                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
+                        prev_reward_batch, dtype=tf.float32),
+                })
             model_out, states = self.model(input_dict, state_batches, seq_len)

             actions_dist = self.dist_class(model_out)
             return actions_dist.sample().numpy(), states, {}

         def loss(self, outputs, samples):
+            assert tf.executing_eagerly()
             return loss_fn(outputs, samples)

         def stats(self, outputs, samples):
+            assert tf.executing_eagerly()
             if stats_fn is None:
                 return {}
             return stats_fn(outputs, samples)
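
With the reworked signature, only name and loss_fn are required; postprocess_fn, stats_fn and optimizer_fn are optional and fall back to no postprocessing, empty stats and a plain Adam optimizer. A rough sketch of building an eager policy class directly, using an illustrative constant loss just to show the argument shapes:

    from ray.rllib.policy.eager_policy import build_tf_policy
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()

    def dummy_loss(outputs, samples):
        # loss_fn receives the policy outputs and the sample batch; a real
        # algorithm would compute its objective from these.
        return tf.constant(0.0)

    # The optional hooks are omitted here and default to None.
    MyEagerPolicy = build_tf_policy(
        name="MyEagerPolicy",
        loss_fn=dummy_loss,
        obs_include_prev_action_reward=True)
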
10 changes: 10 additions & 0 deletions rllib/policy/tf_policy_template.py
@@ -3,6 +3,7 @@
 from __future__ import print_function

 from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
+from ray.rllib.policy import eager_policy
 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
 from ray.rllib.policy.tf_policy import TFPolicy
 from ray.rllib.utils import add_mixins
@@ -211,7 +212,16 @@ def extra_compute_grad_feed_dict(self):
     def with_updates(**overrides):
         return build_tf_policy(**dict(original_kwargs, **overrides))

+    @staticmethod
+    def as_eager():
+        minimal = {
+            k: v for k, v in original_kwargs.items()
+            if v is not None
+        }
+        return eager_policy.build_tf_policy(**minimal)
+
     policy_cls.with_updates = with_updates
+    policy_cls.as_eager = as_eager
     policy_cls.__name__ = name
     policy_cls.__qualname__ = name
     return policy_cls
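
The net effect is that every graph-mode policy class produced by build_tf_policy now carries an as_eager() hook, which is what lets the hand-written eager_pg_policy.py above be deleted. A short sketch; note that as_eager() forwards only the keyword arguments that were actually set when the graph-mode class was built:

    from ray.rllib.agents.pg.pg_policy import PGTFPolicy

    # Rebuilds the class through eager_policy.build_tf_policy(); the trainer
    # template does exactly this when config["use_eager"] is True.
    EagerPGTFPolicy = PGTFPolicy.as_eager()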
