Skip to content

Commit

Permalink
Improve TF backend's Switch function (keras-team#7958)
Browse files Browse the repository at this point in the history
* improve k.switch

* update doc

* handle case when ndim is None

* pep8

* use tf.reshape instead of tf.expand_dims

* unit tests + better error msg

* remove faulty cases

* add broadcasting to cntk backend

* typo

* typo

* update ds

* update ds

* cntk doesn't support symbolic tile

* cntk doesn't support symbolic tile

* allow int n for cntk tile

* int->tuple

* bug fix cntk tile

* fix broadcasting in theano backend

* formatting fixes
  • Loading branch information
farizrahman4u authored and fchollet committed Sep 29, 2017
1 parent 393af21 commit 1bbfdb6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 16 deletions.
20 changes: 18 additions & 2 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,14 +674,16 @@ def squeeze(x, axis):


def tile(x, n):
if isinstance(n, list):
if isinstance(n, int):
n = (n,)
elif isinstance(n, list):
n = tuple(n)

shape = int_shape(x)
num_dynamic_axis = _get_dynamic_axis_num(x)
# Padding the axis
if len(n) < len(shape):
n = tuple([None for _ in range(len(shape) - len(n))]) + n
n = tuple([1 for _ in range(len(shape) - len(n))]) + n

if len(n) != len(shape):
raise NotImplementedError
Expand Down Expand Up @@ -2063,6 +2065,20 @@ def stop_gradient(variables):


def switch(condition, then_expression, else_expression):
    """Switches between two operations depending on a tensor value.

    If `condition` has lower rank than the expressions, it is broadcast
    up to their shape (one trailing axis at a time, via expand_dims +
    tile), since CNTK's element_select requires matching shapes.

    # Arguments
        condition: tensor (`int` or `bool`).
        then_expression: tensor selected where `condition` is true.
        else_expression: tensor selected where `condition` is false.

    # Returns
        The selected tensor.

    # Raises
        ValueError: If rank of `condition` exceeds rank of the expressions.
    """
    ndim_cond = ndim(condition)
    ndim_expr = ndim(then_expression)
    if ndim_cond > ndim_expr:
        # Bug fix: this message previously interpolated the undefined
        # names `cond_ndim`/`expr_ndim`, so this path raised NameError
        # instead of the documented ValueError.
        raise ValueError('Rank of condition should be less'
                         ' than or equal to rank of then and'
                         ' else expressions. ndim(condition)=' +
                         str(ndim_cond) + ', ndim(then_expression)'
                         '=' + str(ndim_expr))
    elif ndim_cond < ndim_expr:
        shape_expr = int_shape(then_expression)
        ndim_diff = ndim_expr - ndim_cond
        # Append trailing axes one at a time, tiling each new axis to
        # the matching dimension of the then/else expressions.
        for i in range(ndim_diff):
            condition = expand_dims(condition)
            condition = tile(condition, shape_expr[ndim_cond + i])
    return C.element_select(condition,
                            then_expression,
                            else_expression)
Expand Down
55 changes: 42 additions & 13 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
def switch(condition, then_expression, else_expression):
    """Switches between two operations depending on a scalar value.

    Note that both `then_expression` and `else_expression`
    should be symbolic tensors of the *same shape*.

    # Arguments
        condition: tensor (`int` or `bool`).
        then_expression: either a tensor, or a callable that returns a tensor.
        else_expression: either a tensor, or a callable that returns a tensor.

    # Returns
        The selected tensor.

    # Raises
        ValueError: If rank of `condition` is greater than rank of expressions.
    """
    if condition.dtype != tf.bool:
        condition = tf.cast(condition, 'bool')
    cond_ndim = ndim(condition)
    if not cond_ndim:
        # Scalar condition: defer to tf.cond, which expects callables
        # for both branches, so wrap plain tensors when needed.
        if not callable(then_expression):
            def then_expression_fn():
                return then_expression
        else:
            then_expression_fn = then_expression
        if not callable(else_expression):
            def else_expression_fn():
                return else_expression
        else:
            else_expression_fn = else_expression
        x = tf.cond(condition,
                    then_expression_fn,
                    else_expression_fn)
    else:
        # tf.where needs its condition tensor
        # to be the same shape as its two
        # result tensors
        if callable(then_expression):
            then_expression = then_expression()
        if callable(else_expression):
            else_expression = else_expression()
        expr_ndim = ndim(then_expression)
        if cond_ndim > expr_ndim:
            raise ValueError('Rank of `condition` should be less than or'
                             ' equal to rank of `then_expression` and '
                             '`else_expression`. ndim(condition)=' +
                             str(cond_ndim) + ', ndim(then_expression)'
                             '=' + str(expr_ndim))
        if cond_ndim > 1:
            # Broadcast the condition up to the expressions' shape:
            # append singleton trailing axes, then tile along them.
            ndim_diff = expr_ndim - cond_ndim
            cond_shape = tf.concat([tf.shape(condition), [1] * ndim_diff],
                                   axis=0)
            condition = tf.reshape(condition, cond_shape)
            expr_shape = tf.shape(then_expression)
            shape_diff = expr_shape - cond_shape
            tile_shape = tf.where(shape_diff > 0, expr_shape,
                                  tf.ones_like(expr_shape))
            condition = tf.tile(condition, tile_shape)
        x = tf.where(condition, then_expression, else_expression)
    return x


Expand Down
6 changes: 6 additions & 0 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,12 @@ def switch(condition, then_expression, else_expression):
then_expression = then_expression()
if callable(else_expression):
else_expression = else_expression()
cond_ndim = ndim(condition)
expr_ndim = ndim(then_expression)
if cond_ndim < expr_ndim:
ndim_diff = expr_ndim - cond_ndim
for _ in range(ndim_diff):
condition = expand_dims(condition)
return T.switch(condition, then_expression, else_expression)


Expand Down
15 changes: 14 additions & 1 deletion tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,14 +648,27 @@ def test_logsumexp_optim(self):
rtol=1e-5)

def test_switch(self):
    """Check that `K.switch` agrees across backends, both for a scalar
    condition and for conditions of equal or lower rank than the
    then/else expressions (broadcasting)."""
    # Scalar condition: select between two scalings of the same value.
    scalar_val = np.random.random()
    results = []
    for backend in BACKENDS:
        var = backend.variable(scalar_val)
        selected = backend.switch(backend.greater_equal(var, 0.5),
                                  var * 0.1, var * 0.2)
        results.append(backend.eval(selected))

    assert_list_pairwise(results)
    # Non-scalar conditions: (cond_shape, then_shape, else_shape) triples,
    # including conditions of lower rank than the expressions.
    shape_triples = [
        [(4, 3, 2), (4, 3, 2), (4, 3, 2)],
        [(4, 3,), (4, 3, 2), (4, 3, 2)],
        [(4,), (4, 3, 2), (4, 3, 2)],
    ]
    for triple in shape_triples:
        arrays = [np.random.random(shape) for shape in triple]
        results = []
        for backend in BACKENDS:
            cond_var, then_var, else_var = [backend.variable(a)
                                            for a in arrays]
            cond = backend.greater_equal(cond_var, 0.5)
            results.append(backend.eval(backend.switch(cond, then_var,
                                                       else_var)))
        assert_list_pairwise(results)

def test_dropout(self):
val = np.random.random((100, 100))
Expand Down

0 comments on commit 1bbfdb6

Please sign in to comment.