import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

TF2_GLOBAL_SHARED_LAYER = None
class TF2SharedWeightsModel(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    NOTE: This will only work for tf2.x. When running with config.framework=tf,
    use SharedWeightsModel1 and SharedWeightsModel2 below, instead!

    The shared (single) layer is simply defined outside of the two Models,
    then used by both Models in their forward pass.
    """

    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        super().__init__(
            observation_space, action_space, num_outputs, model_config, name
        )

        global TF2_GLOBAL_SHARED_LAYER
        # The global, shared layer to be used by both models.
        if TF2_GLOBAL_SHARED_LAYER is None:
            TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1"
            )

        inputs = tf.keras.layers.Input(observation_space.shape)
        last_layer = TF2_GLOBAL_SHARED_LAYER(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out"
        )(last_layer)
        vf = tf.keras.layers.Dense(units=1, activation=None, name="value_out")(
            last_layer
        )
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
class SharedWeightsModel1(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    NOTE: This will only work for tf1 (static graph). When running with
    config.framework_str=tf2, use TF2SharedWeightsModel, instead!

    Here, we share the variables defined in the 'shared' variable scope
    by entering it explicitly with tf1.AUTO_REUSE. This creates the
    variables for the 'fc1' layer in a global scope called 'shared'
    (outside of the Policy's normal variable scope).
    """

    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        super().__init__(
            observation_space, action_space, num_outputs, model_config, name
        )

        inputs = tf.keras.layers.Input(observation_space.shape)
        with tf1.variable_scope(
            tf1.VariableScope(tf1.AUTO_REUSE, "shared"),
            reuse=tf1.AUTO_REUSE,
            auxiliary_name_scope=False,
        ):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1"
            )(inputs)

        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out"
        )(last_layer)
        vf = tf.keras.layers.Dense(units=1, activation=None, name="value_out")(
            last_layer
        )
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
class SharedWeightsModel2(TFModelV2):
    """The "other" TFModelV2 using the same shared space as the one above."""

    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        super().__init__(
            observation_space, action_space, num_outputs, model_config, name
        )

        inputs = tf.keras.layers.Input(observation_space.shape)

        # Weights shared with SharedWeightsModel1.
        with tf1.variable_scope(
            tf1.VariableScope(tf1.AUTO_REUSE, "shared"),
            reuse=tf1.AUTO_REUSE,
            auxiliary_name_scope=False,
        ):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1"
            )(inputs)

        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out"
        )(last_layer)
        vf = tf.keras.layers.Dense(units=1, activation=None, name="value_out")(
            last_layer
        )
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
TORCH_GLOBAL_SHARED_LAYER = None
if torch:
    # The global, shared layer to be used by both models.
    TORCH_GLOBAL_SHARED_LAYER = SlimFC(
        64,
        64,
        activation_fn=nn.ReLU,
        initializer=torch.nn.init.xavier_uniform_,
    )
class TorchSharedWeightsModel(TorchModelV2, nn.Module):
    """Example of weight sharing between two different TorchModelV2s.

    The shared (single) layer is simply defined outside of the two Models,
    then used by both Models in their forward pass.
    """

    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        TorchModelV2.__init__(
            self, observation_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        # Non-shared initial layer.
        # Note: np.prod is used here (np.product is deprecated/removed in
        # newer NumPy releases); the result is identical.
        self.first_layer = SlimFC(
            int(np.prod(observation_space.shape)),
            64,
            activation_fn=nn.ReLU,
            initializer=torch.nn.init.xavier_uniform_,
        )

        # Non-shared final layer.
        self.last_layer = SlimFC(
            64,
            self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self.vf = SlimFC(
            64,
            1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self._global_shared_layer = TORCH_GLOBAL_SHARED_LAYER
        self._output = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out = self.first_layer(input_dict["obs"])
        self._output = self._global_shared_layer(out)
        model_out = self.last_layer(self._output)
        return model_out, []

    @override(ModelV2)
    def value_function(self):
        assert self._output is not None, "must call forward first!"
        return torch.reshape(self.vf(self._output), [-1])
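# ---------------------------------------------------------------------------
# Usage sketch (not part of the original example): a minimal check that two
# independently constructed TorchSharedWeightsModel instances really do share
# the single TORCH_GLOBAL_SHARED_LAYER module. The observation/action spaces
# and batch sizes below are illustrative assumptions; newer Ray versions pair
# with `gymnasium`, older ones with `gym`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gymnasium as gym  # assumption: gymnasium-based Ray/RLlib install

    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
    act_space = gym.spaces.Discrete(2)

    # Build two separate models; each constructor only *references* the
    # already-created global SlimFC layer, it does not copy it.
    model_a = TorchSharedWeightsModel(obs_space, act_space, 2, {}, "model_a")
    model_b = TorchSharedWeightsModel(obs_space, act_space, 2, {}, "model_b")

    # Same Python object => same weights; gradients flowing through one
    # model's shared layer update the other model's shared layer as well.
    assert model_a._global_shared_layer is model_b._global_shared_layer

    obs = torch.rand(3, 4)
    out_a, _ = model_a.forward({"obs": obs}, [], None)
    out_b, _ = model_b.forward({"obs": obs}, [], None)
    print(out_a.shape, out_b.shape)  # torch.Size([3, 2]) torch.Size([3, 2])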