
Commit f2d00a4

Add a partitioner that partitions a variable such that no partition gets less than a given minimum slice size.
Switch from 'layers.legacy_fully_connected' to 'layers.fully_connected' in _DNNLinearCombinedBaseEstimator. Also partition the weights of the fully connected layers using the added partitioner. Change: 125492475
1 parent 95237d1

File tree

4 files changed (+183, -18)

tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py

Lines changed: 27 additions & 17 deletions
@@ -41,7 +41,9 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import training
 
@@ -247,24 +249,32 @@ def _dnn_logits(self, features, is_training=False):
         self._get_dnn_feature_columns(),
         weight_collections=[self._dnn_weight_collection])
     for layer_id, num_hidden_units in enumerate(self._dnn_hidden_units):
-      net = layers.legacy_fully_connected(
-          net,
-          num_hidden_units,
-          activation_fn=self._dnn_activation_fn,
-          weight_collections=[self._dnn_weight_collection],
-          bias_collections=[self._dnn_weight_collection],
-          name="hiddenlayer_%d" % layer_id)
-      if self._dnn_dropout is not None and is_training:
-        net = layers.dropout(
+      op_scope = "hiddenlayer_%d" % layer_id
+      with variable_scope.variable_op_scope(
+          [net], op_scope,
+          partitioner=partitioned_variables.min_max_variable_partitioner(
+              max_partitions=self._config.num_ps_replicas)):
+        net = layers.fully_connected(
             net,
-            keep_prob=(1.0 - self._dnn_dropout))
-      self._add_hidden_layer_summary(net, "hiddenlayer_%d" % layer_id)
-    logit = layers.legacy_fully_connected(
-        net,
-        self._num_label_columns(),
-        weight_collections=[self._dnn_weight_collection],
-        bias_collections=[self._dnn_weight_collection],
-        name="dnn_logit")
+            num_hidden_units,
+            activation_fn=self._dnn_activation_fn,
+            variables_collections=[self._dnn_weight_collection],
+            scope=op_scope)
+        if self._dnn_dropout is not None and is_training:
+          net = layers.dropout(
+              net,
+              keep_prob=(1.0 - self._dnn_dropout))
+      self._add_hidden_layer_summary(net, op_scope)
+    with variable_scope.variable_op_scope(
+        [net], "dnn_logit",
+        partitioner=partitioned_variables.min_max_variable_partitioner(
+            max_partitions=self._config.num_ps_replicas)):
+      logit = layers.fully_connected(
+          net,
+          self._num_label_columns(),
+          activation_fn=None,
+          variables_collections=[self._dnn_weight_collection],
+          scope="dnn_logit")
     self._add_hidden_layer_summary(logit, "dnn_logit")
     return logit
 
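The net effect of this change is that every weight matrix created by `layers.fully_connected` in the estimator now lives under a variable scope whose partitioner can shard it across parameter servers. Below is a minimal standalone sketch of the same pattern, not part of the commit; the input shape, scope name, and `max_partitions=3` are illustrative:

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

inputs = tf.placeholder(tf.float32, shape=[None, 1024])
with variable_scope.variable_op_scope(
    [inputs], "sharded_layer",
    partitioner=partitioned_variables.min_max_variable_partitioner(
        max_partitions=3)):
  # The 1024 x 512 float32 weight matrix is 2MB; the default min_slice_size
  # of 256K would allow 8 shards, but max_partitions caps it at 3 shards
  # along axis 0.
  outputs = layers.fully_connected(inputs, 512, scope="sharded_layer")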

tensorflow/python/kernel_tests/partitioned_variables_test.py

Lines changed: 87 additions & 0 deletions
@@ -131,6 +131,93 @@ def testVariableAxisSizePartitioner(self):
       self.assertEqual(len(v3str_list), 4)
       self.assertAllEqual(v3str_part, (1, 1, 1, 4))
 
+  def _testMinMaxVariablePartitioner(self, max_partitions, axis, min_slice_size,
+                                     var_name, var_shape,
+                                     expected_axis_shards, expected_partitions):
+    partitioner = tf.min_max_variable_partitioner(max_partitions=max_partitions,
+                                                  axis=axis,
+                                                  min_slice_size=min_slice_size)
+    with tf.variable_scope("root", partitioner=partitioner):
+      v0 = tf.get_variable(var_name, dtype=tf.float32, shape=var_shape)
+      v0_list = v0._get_variable_list()
+      v0_part = v0._get_partitions()
+      self.assertEqual(len(v0_list), expected_axis_shards)
+      self.assertAllEqual(v0_part, expected_partitions)
+
+  def testMinMaxVariablePartitioner(self):
+    with self.test_session():
+      # Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
+                                          min_slice_size=2 << 10,
+                                          var_name="v0_0", var_shape=[2048],
+                                          expected_axis_shards=4,
+                                          expected_partitions=[4])
+
+      # Partitioning a variable of shape=[2048, 1024] with a minimum of 256K per
+      # slice.
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
+                                          min_slice_size=256 << 10,
+                                          var_name="v0", var_shape=[2048, 1024],
+                                          expected_axis_shards=32,
+                                          expected_partitions=[32, 1])
+
+      # max_partitions restricts partitioning of the variable.
+      self._testMinMaxVariablePartitioner(max_partitions=16, axis=0,
+                                          min_slice_size=256 << 10,
+                                          var_name="v1_max",
+                                          var_shape=[2048, 1024],
+                                          expected_axis_shards=16,
+                                          expected_partitions=[16, 1])
+      self._testMinMaxVariablePartitioner(max_partitions=1, axis=0,
+                                          min_slice_size=256 << 10,
+                                          var_name="v2_max",
+                                          var_shape=[2048, 1024],
+                                          expected_axis_shards=1,
+                                          expected_partitions=[1, 1])
+
+      # Reducing/Increasing min_slice_size proportionately increases/reduces the
+      # number of partitions.
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
+                                          min_slice_size=128 << 10,
+                                          var_name="v3_slice",
+                                          var_shape=[2048, 1024],
+                                          expected_axis_shards=64,
+                                          expected_partitions=[64, 1])
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
+                                          min_slice_size=512 << 10,
+                                          var_name="v4_slice",
+                                          var_shape=[2048, 1024],
+                                          expected_axis_shards=16,
+                                          expected_partitions=[16, 1])
+
+      # Partitioning the variable along a different axis.
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=1,
+                                          min_slice_size=256 << 10,
+                                          var_name="v5_axis",
+                                          var_shape=[64, 1024, 1, 3],
+                                          expected_axis_shards=3,
+                                          expected_partitions=[1, 3, 1, 1])
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=3,
+                                          min_slice_size=256 << 10,
+                                          var_name="v6_axis",
+                                          var_shape=[64, 1024, 1, 3],
+                                          expected_axis_shards=3,
+                                          expected_partitions=[1, 1, 1, 3])
+
+      # Can not partition the variable more than what its shape allows.
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
+                                          min_slice_size=256 << 10,
+                                          var_name="v7_shape",
+                                          var_shape=[16, 128, 1024],
+                                          expected_axis_shards=16,
+                                          expected_partitions=[16, 1, 1])
+      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
+                                          min_slice_size=256 << 10,
+                                          var_name="v8_shape",
+                                          var_shape=[4, 512, 1024],
+                                          expected_axis_shards=4,
+                                          expected_partitions=[4, 1, 1])
+
 
 def _IotaInitializer(shape, dtype=tf.float32):
   assert dtype == tf.float32
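The expected shard counts in these tests follow directly from the partitioner's size math: shards along the chosen axis = max(1, min(shape[axis], max_partitions, ceil(total_bytes / min_slice_size))). A quick plain-Python check of a few of the cases above (a sketch, not part of the test, assuming 4 bytes per float32 element):

import math

def expected_shards(var_shape, axis, min_slice_size, max_partitions,
                    bytes_per_element=4):  # float32
  total_bytes = bytes_per_element
  for dim in var_shape:
    total_bytes *= dim
  # Clamp by both max_partitions and the length of the partitioned axis.
  return max(1, min(var_shape[axis], max_partitions,
                    int(math.ceil(float(total_bytes) / min_slice_size))))

print(expected_shards([2048], 0, 2 << 10, 100))            # 4
print(expected_shards([2048, 1024], 0, 256 << 10, 100))    # 32
print(expected_shards([2048, 1024], 0, 256 << 10, 16))     # 16
print(expected_shards([4, 512, 1024], 0, 256 << 10, 100))  # 4 (shape-capped)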

tensorflow/python/ops/partitioned_variables.py

Lines changed: 68 additions & 1 deletion
@@ -61,7 +61,11 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 
-__all__ = ["create_partitioned_variables", "variable_axis_size_partitioner"]
+__all__ = [
+    "create_partitioned_variables",
+    "variable_axis_size_partitioner",
+    "min_max_variable_partitioner",
+]
 
 
 def variable_axis_size_partitioner(
@@ -148,6 +152,69 @@ def _partitioner(shape, dtype):
   return _partitioner
 
 
+def min_max_variable_partitioner(max_partitions=1, axis=0,
+                                 min_slice_size=256 << 10,
+                                 bytes_per_string_element=16):
+  """Partitioner to allocate minimum size per slice.
+
+  Returns a partitioner that partitions the variable of given shape and dtype
+  such that each partition has a minimum of `min_slice_size` slice of the
+  variable. The maximum number of such partitions (upper bound) is given by
+  `max_partitions`.
+
+  Args:
+    max_partitions: Upper bound on the number of partitions. Defaults to 1.
+    axis: Axis along which to partition the variable. Defaults to 0.
+    min_slice_size: Minimum size of the variable slice per partition. Defaults
+      to 256K.
+    bytes_per_string_element: If the `Variable` is of type string, this
+      provides an estimate of how large each scalar in the `Variable` is.
+
+  Returns:
+    A partition function usable as the `partitioner` argument to
+    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
+  """
+  def _partitioner(shape, dtype):
+    """Partitioner that partitions list for a variable of given shape and type.
+
+    Ex: Consider partitioning a variable of type float32 with
+      shape=[1024, 1024].
+      If `max_partitions` >= 16, this function would return
+        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
+      If `max_partitions` < 16, this function would return
+        [`max_partitions`, 1].
+
+    Args:
+      shape: Shape of the variable.
+      dtype: Type of the variable.
+
+    Returns:
+      List of partitions for each axis (currently only one axis can be
+      partitioned).
+
+    Raises:
+      ValueError: If axis to partition along does not exist for the variable.
+    """
+    if axis >= len(shape):
+      raise ValueError("Can not partition variable along axis %d when shape is "
+                       "only %s" % (axis, shape))
+    if dtype.base_dtype == dtypes.string:
+      bytes_per_element = bytes_per_string_element
+    else:
+      bytes_per_element = dtype.size
+    total_size_bytes = shape.num_elements() * bytes_per_element
+    partitions = total_size_bytes / min_slice_size
+    partitions_list = [1] * len(shape)
+    # We can not partition the variable beyond what its shape or
+    # `max_partitions` allows.
+    partitions_list[axis] = max(1, min(shape[axis].value,
+                                       max_partitions,
+                                       int(math.ceil(partitions))))
+    return partitions_list
+  return _partitioner
+
+
 def create_partitioned_variables(
     shape, slicing, initializer, dtype=dtypes.float32,
     trainable=True, collections=None, name=None, reuse=None):
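As the test file above shows, the new function is also exported as `tf.min_max_variable_partitioner` and plugs into `tf.variable_scope`. A minimal usage sketch (variable name and shape are illustrative): a [2048, 1024] float32 variable is 8MB, so 256K minimum slices yield 32 shards along axis 0.

import tensorflow as tf

partitioner = tf.min_max_variable_partitioner(max_partitions=100, axis=0,
                                              min_slice_size=256 << 10)
with tf.variable_scope("embeddings", partitioner=partitioner):
  # Created as 32 shards of shape [64, 1024] each (8MB / 256K = 32).
  weights = tf.get_variable("weights", dtype=tf.float32, shape=[2048, 1024])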

tensorflow/python/ops/state_ops.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@
 ## Variable Partitioners for Sharding
 
 @@variable_axis_size_partitioner
+@@min_max_variable_partitioner
 
 ## Sparse Variable Updates
