
Commit

[RLlib] Issue 12233 shared tf layers example not really shared (only works for tf1.x, not tf2.x). (ray-project#12399)
sven1977 authored Nov 25, 2020
1 parent 95175a8 commit 841d93d
Showing 3 changed files with 57 additions and 6 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1899,7 +1899,7 @@ py_test(
tags = ["examples", "examples_M"],
size = "medium",
srcs = ["examples/multi_agent_cartpole.py"],
args = ["--as-test", "--torch", "--stop-reward=70.0", "--num-cpus=4"]
args = ["--as-test", "--framework=torch", "--stop-reward=70.0", "--num-cpus=4"]
)

py_test(
44 changes: 44 additions & 0 deletions rllib/examples/models/shared_weights_model.py
@@ -10,10 +10,53 @@
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

TF2_GLOBAL_SHARED_LAYER = None
if tf:
# The global, shared layer to be used by both models.
TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
units=64, activation=tf.nn.relu, name="fc1")


class TF2SharedWeightsModel(TFModelV2):
"""Example of weight sharing between two different TFModelV2s.
NOTE: This will only work for tf2.x. When running with config.framework=tf,
use SharedWeightsModel1 and SharedWeightsModel2 below, instead!
The shared (single) layer is simply defined outside of the two Models,
then used by both Models in their forward pass.
"""

def __init__(self, observation_space, action_space, num_outputs,
model_config, name):
super().__init__(observation_space, action_space, num_outputs,
model_config, name)

inputs = tf.keras.layers.Input(observation_space.shape)
last_layer = TF2_GLOBAL_SHARED_LAYER(inputs)
output = tf.keras.layers.Dense(
units=num_outputs, activation=None, name="fc_out")(last_layer)
vf = tf.keras.layers.Dense(
units=1, activation=None, name="value_out")(last_layer)
self.base_model = tf.keras.models.Model(inputs, [output, vf])
self.register_variables(self.base_model.variables)

@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
out, self._value_out = self.base_model(input_dict["obs"])
return out, []

@override(ModelV2)
def value_function(self):
return tf.reshape(self._value_out, [-1])
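
For context, a minimal standalone sketch (not part of this commit) of the sharing mechanism used above: one Keras layer instance is created at module level and called inside two functional models, so both models end up holding the very same kernel and bias variables.

    import tensorflow as tf

    # One layer object, created once; both models below reuse its variables.
    SHARED_FC1 = tf.keras.layers.Dense(64, activation=tf.nn.relu, name="fc1")

    def build_model(obs_dim=4, num_outputs=2):
        inputs = tf.keras.layers.Input(shape=(obs_dim, ))
        hidden = SHARED_FC1(inputs)  # same kernel/bias in every model built here
        out = tf.keras.layers.Dense(num_outputs, name="fc_out")(hidden)
        return tf.keras.Model(inputs, out)

    model_a, model_b = build_model(), build_model()
    # Both models contain the identical layer (and variable) objects for "fc1".
    assert model_a.get_layer("fc1") is model_b.get_layer("fc1")
    assert model_a.get_layer("fc1").kernel is SHARED_FC1.kernel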


class SharedWeightsModel1(TFModelV2):
"""Example of weight sharing between two different TFModelV2s.
NOTE: This will only work for tf1 (static graph). When running with
config.framework=tf2, use TF2SharedWeightsModel, instead!
Here, we share the variables defined in the 'shared' variable scope
by entering it explicitly with tf1.AUTO_REUSE. This creates the
variables for the 'fc1' layer in a global scope called 'shared'
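
A minimal sketch of the tf1-style sharing this docstring describes, assuming the tf.compat.v1 API with eager execution disabled: re-entering the same variable scope with reuse=tf1.AUTO_REUSE makes the second call pick up the variables the first call created.

    import tensorflow.compat.v1 as tf1
    tf1.disable_eager_execution()

    def shared_hidden(obs):
        # Re-entering the "shared" scope with AUTO_REUSE reuses the existing
        # "shared/fc1" kernel and bias instead of creating new ones.
        with tf1.variable_scope("shared", reuse=tf1.AUTO_REUSE):
            return tf1.layers.dense(obs, 64, activation=tf1.nn.relu, name="fc1")

    obs_a = tf1.placeholder(tf1.float32, [None, 4])
    obs_b = tf1.placeholder(tf1.float32, [None, 4])
    out_a, out_b = shared_hidden(obs_a), shared_hidden(obs_b)

    # Only one kernel and one bias exist, no matter how many graphs call this.
    fc1_vars = [v for v in tf1.global_variables()
                if v.name.startswith("shared/fc1")]
    assert len(fc1_vars) == 2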
@@ -85,6 +128,7 @@ def value_function(self):

TORCH_GLOBAL_SHARED_LAYER = None
if torch:
# The global, shared layer to be used by both models.
TORCH_GLOBAL_SHARED_LAYER = SlimFC(
64,
64,
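The torch variant relies on the same principle: a single module instance (plain nn.Linear here, standing in for RLlib's SlimFC wrapper) assigned to two networks shares its parameters automatically. A minimal sketch:

    import torch
    import torch.nn as nn

    SHARED_FC = nn.Linear(64, 64)  # one parameter set, created once

    class SmallNet(nn.Module):
        def __init__(self, shared):
            super().__init__()
            self.shared = shared          # same Parameter objects in both nets
            self.head = nn.Linear(64, 2)

        def forward(self, x):
            return self.head(torch.relu(self.shared(x)))

    net_a, net_b = SmallNet(SHARED_FC), SmallNet(SHARED_FC)
    assert net_a.shared.weight is net_b.shared.weight  # truly shared
    # A gradient step through either net updates the one shared weight matrix.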
17 changes: 12 additions & 5 deletions rllib/examples/multi_agent_cartpole.py
@@ -18,7 +18,8 @@
from ray import tune
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.models.shared_weights_model import \
-    SharedWeightsModel1, SharedWeightsModel2, TorchSharedWeightsModel
+    SharedWeightsModel1, SharedWeightsModel2, TF2SharedWeightsModel, \
+    TorchSharedWeightsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
@@ -35,16 +36,22 @@
parser.add_argument("--simple", action="store_true")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--torch", action="store_true")
parser.add_argument(
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")

if __name__ == "__main__":
args = parser.parse_args()

ray.init(num_cpus=args.num_cpus or None)

# Register the models to use.
-    mod1 = TorchSharedWeightsModel if args.torch else SharedWeightsModel1
-    mod2 = TorchSharedWeightsModel if args.torch else SharedWeightsModel2
+    if args.framework == "torch":
+        mod1 = mod2 = TorchSharedWeightsModel
+    elif args.framework in ["tfe", "tf2"]:
+        mod1 = mod2 = TF2SharedWeightsModel
+    else:
+        mod1 = SharedWeightsModel1
+        mod2 = SharedWeightsModel2
ModelCatalog.register_custom_model("model1", mod1)
ModelCatalog.register_custom_model("model2", mod2)

@@ -83,7 +90,7 @@ def gen_policy(i):
"policies": policies,
"policy_mapping_fn": (lambda agent_id: random.choice(policy_ids)),
},
"framework": "torch" if args.torch else "tf",
"framework": args.framework,
}
stop = {
"episode_reward_mean": args.stop_reward,
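The remainder of the script is elided above; purely as a hypothetical sketch (the algorithm name and exact call are assumptions, not the elided code), the assembled config and stop criteria would typically be handed to Tune and, when --as-test is set, checked against the stop reward:

    # "PPO" is illustrative; the example's actual trainable may differ.
    results = tune.run("PPO", stop=stop, config=config, verbose=1)
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()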
