diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py
index 217237a63e9..830afd7a475 100644
--- a/keras/backend/cntk_backend.py
+++ b/keras/backend/cntk_backend.py
@@ -674,14 +674,16 @@ def squeeze(x, axis):
 
 
 def tile(x, n):
-    if isinstance(n, list):
+    if isinstance(n, int):
+        n = (n,)
+    elif isinstance(n, list):
         n = tuple(n)
 
     shape = int_shape(x)
     num_dynamic_axis = _get_dynamic_axis_num(x)
     # Padding the axis
     if len(n) < len(shape):
-        n = tuple([None for _ in range(len(shape) - len(n))]) + n
+        n = tuple([1 for _ in range(len(shape) - len(n))]) + n
 
     if len(n) != len(shape):
         raise NotImplementedError
@@ -2063,6 +2065,20 @@ def stop_gradient(variables):
 
 
 def switch(condition, then_expression, else_expression):
+    ndim_cond = ndim(condition)
+    ndim_expr = ndim(then_expression)
+    if ndim_cond > ndim_expr:
+        raise ValueError('Rank of condition should be less'
+                         ' than or equal to rank of then and'
+                         ' else expressions. ndim(condition)=' +
+                         str(ndim_cond) + ', ndim(then_expression)'
+                         '=' + str(ndim_expr))
+    elif ndim_cond < ndim_expr:
+        shape_expr = int_shape(then_expression)
+        ndim_diff = ndim_expr - ndim_cond
+        for i in range(ndim_diff):
+            condition = expand_dims(condition)
+            condition = tile(condition, shape_expr[ndim_cond + i])
     return C.element_select(condition, then_expression, else_expression)
 
 
diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index 69736fc33f3..75b5abcf560 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -2561,28 +2561,57 @@ def switch(condition, then_expression, else_expression):
     should be symbolic tensors of the *same shape*.
 
     # Arguments
-        condition: scalar tensor (`int` or `bool`).
+        condition: tensor (`int` or `bool`).
         then_expression: either a tensor, or a callable that
             returns a tensor.
         else_expression: either a tensor, or a callable that
             returns a tensor.
 
     # Returns
         The selected tensor.
+
+    # Raises
+        ValueError: If rank of `condition` is greater than rank of expressions.
     """
     if condition.dtype != tf.bool:
         condition = tf.cast(condition, 'bool')
-    if not callable(then_expression):
-        def then_expression_fn():
-            return then_expression
-    else:
-        then_expression_fn = then_expression
-    if not callable(else_expression):
-        def else_expression_fn():
-            return else_expression
+    cond_ndim = ndim(condition)
+    if not cond_ndim:
+        if not callable(then_expression):
+            def then_expression_fn():
+                return then_expression
+        else:
+            then_expression_fn = then_expression
+        if not callable(else_expression):
+            def else_expression_fn():
+                return else_expression
+        else:
+            else_expression_fn = else_expression
+        x = tf.cond(condition,
+                    then_expression_fn,
+                    else_expression_fn)
     else:
-        else_expression_fn = else_expression
-    x = tf.cond(condition,
-                then_expression_fn,
-                else_expression_fn)
+        # tf.where needs its condition tensor
+        # to be the same shape as its two
+        # result tensors
+        if callable(then_expression):
+            then_expression = then_expression()
+        if callable(else_expression):
+            else_expression = else_expression()
+        expr_ndim = ndim(then_expression)
+        if cond_ndim > expr_ndim:
+            raise ValueError('Rank of `condition` should be less than or'
+                             ' equal to rank of `then_expression` and '
+                             '`else_expression`. ndim(condition)=' +
+                             str(cond_ndim) + ', ndim(then_expression)'
+                             '=' + str(expr_ndim))
+        if cond_ndim > 1:
+            ndim_diff = expr_ndim - cond_ndim
+            cond_shape = tf.concat([tf.shape(condition), [1] * ndim_diff], axis=0)
+            condition = tf.reshape(condition, cond_shape)
+            expr_shape = tf.shape(then_expression)
+            shape_diff = expr_shape - cond_shape
+            tile_shape = tf.where(shape_diff > 0, expr_shape,
+                                  tf.ones_like(expr_shape))
+            condition = tf.tile(condition, tile_shape)
+        x = tf.where(condition, then_expression, else_expression)
     return x
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index d4bbcdf2e30..3fc4e8ac527 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -1457,6 +1457,12 @@ def switch(condition, then_expression, else_expression):
         then_expression = then_expression()
     if callable(else_expression):
         else_expression = else_expression()
+    cond_ndim = ndim(condition)
+    expr_ndim = ndim(then_expression)
+    if cond_ndim < expr_ndim:
+        ndim_diff = expr_ndim - cond_ndim
+        for _ in range(ndim_diff):
+            condition = expand_dims(condition)
     return T.switch(condition, then_expression, else_expression)
 
 
diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py
index e45e72c5774..9a4bdf807d9 100644
--- a/tests/keras/backend/backend_test.py
+++ b/tests/keras/backend/backend_test.py
@@ -648,13 +648,27 @@ def test_logsumexp_optim(self):
                             rtol=1e-5)
 
     def test_switch(self):
+        # scalar
         val = np.random.random()
         z_list = []
         for k in BACKENDS:
             x = k.variable(val)
             x = k.switch(k.greater_equal(x, 0.5), x * 0.1, x * 0.2)
             z_list.append(k.eval(x))
         assert_list_pairwise(z_list)
+        # non scalar
+        shapes = []
+        shapes.append([(4, 3, 2), (4, 3, 2), (4, 3, 2)])
+        shapes.append([(4, 3,), (4, 3, 2), (4, 3, 2)])
+        shapes.append([(4,), (4, 3, 2), (4, 3, 2)])
+        for s in shapes:
+            z_list = []
+            arrays = list(map(np.random.random, s))
+            for k in BACKENDS:
+                x, then_expr, else_expr = map(k.variable, arrays)
+                cond = k.greater_equal(x, 0.5)
+                z_list.append(k.eval(k.switch(cond, then_expr, else_expr)))
+            assert_list_pairwise(z_list)
 
     def test_dropout(self):
         val = np.random.random((100, 100))