Skip to content

Commit

Permalink
tf backend supports bool variable (keras-team#7832)
Browse files Browse the repository at this point in the history
* ENH: use tf.as_dtype to convert dtype

* TST: add unit test

* ENH: keep compatible

* DOC: add argument desc

* TST: more test cases

* TST: only test on tensorflow

* CLN: replace _convert_string_dtype by tf.as_dtype

* BUG: random_normal_variable should check dtype

* CLN: join lines

* DOC: add raise section

* CLN: remove type check
  • Loading branch information
facaiy authored and fchollet committed Sep 9, 2017
1 parent 0b3dc16 commit ba29c60
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 38 deletions.
40 changes: 7 additions & 33 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,32 +180,6 @@ def set_session(session):

# VARIABLE MANIPULATION

def _convert_string_dtype(dtype):
"""Get the type from a string.
# Arguments
dtype: A string representation of a type.
# Returns
The type requested.
# Raises
ValueError: if `dtype` is not supported.
"""
mapping = {'float16': tf.float16,
'float32': tf.float32,
'float64': tf.float64,
'int16': tf.int16,
'int32': tf.int32,
'int64': tf.int64,
'uint8': tf.int8,
'uint16': tf.uint16}

if dtype not in mapping:
raise ValueError('Unsupported dtype:', dtype)
return mapping[dtype]


def _to_tensor(x, dtype):
"""Convert the input `x` to a tensor of type `dtype`.
Expand Down Expand Up @@ -313,7 +287,7 @@ def variable(value, dtype=None, name=None, constraint=None):
v._keras_shape = sparse_coo.shape
v._uses_learning_phase = False
return v
v = tf.Variable(value, dtype=_convert_string_dtype(dtype), name=name)
v = tf.Variable(value, dtype=tf.as_dtype(dtype), name=name)
if isinstance(value, np.ndarray):
v._keras_shape = value.shape
elif hasattr(value, 'get_shape'):
Expand Down Expand Up @@ -621,7 +595,7 @@ def zeros(shape, dtype=None, name=None):
"""
if dtype is None:
dtype = floatx()
tf_dtype = _convert_string_dtype(dtype)
tf_dtype = tf.as_dtype(dtype)
return variable(tf.constant_initializer(0., dtype=tf_dtype)(shape),
dtype, name)

Expand Down Expand Up @@ -649,7 +623,7 @@ def ones(shape, dtype=None, name=None):
"""
if dtype is None:
dtype = floatx()
tf_dtype = _convert_string_dtype(dtype)
tf_dtype = tf.as_dtype(dtype)
return variable(tf.constant_initializer(1., dtype=tf_dtype)(shape),
dtype, name)

Expand Down Expand Up @@ -769,7 +743,7 @@ def random_uniform_variable(shape, low, high, dtype=None,
"""
if dtype is None:
dtype = floatx()
tf_dtype = _convert_string_dtype(dtype)
tf_dtype = tf.as_dtype(dtype)
if seed is None:
# ensure that randomness is conditioned by the Numpy RNG
seed = np.random.randint(10e8)
Expand Down Expand Up @@ -806,7 +780,7 @@ def random_normal_variable(shape, mean, scale, dtype=None,
"""
if dtype is None:
dtype = floatx()
tf_dtype = _convert_string_dtype(dtype)
tf_dtype = tf.as_dtype(dtype)
if seed is None:
# ensure that randomness is conditioned by the Numpy RNG
seed = np.random.randint(10e8)
Expand Down Expand Up @@ -2154,7 +2128,7 @@ def set_value(x, value):
(of the same shape).
"""
value = np.asarray(value, dtype=dtype(x))
tf_dtype = _convert_string_dtype(x.dtype.name.split('_')[0])
tf_dtype = tf.as_dtype(x.dtype.name.split('_')[0])
if hasattr(x, '_assign_placeholder'):
assign_placeholder = x._assign_placeholder
assign_op = x._assign_op
Expand All @@ -2178,7 +2152,7 @@ def batch_set_value(tuples):
feed_dict = {}
for x, value in tuples:
value = np.asarray(value, dtype=dtype(x))
tf_dtype = _convert_string_dtype(x.dtype.name.split('_')[0])
tf_dtype = tf.as_dtype(x.dtype.name.split('_')[0])
if hasattr(x, '_assign_placeholder'):
assign_placeholder = x._assign_placeholder
assign_op = x._assign_op
Expand Down
13 changes: 8 additions & 5 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,6 @@ def test_random_variables(self):
mean=0., scale=1.,
shape_or_val=False, assert_value_equality=False)

# not supported dtype
for dtype in ['int16', 'int32', 'int64', 'uint8', 'uint16', 'double']:
with pytest.raises(ValueError):
ztf = KTF.random_normal_variable((2, 3), 0, 1, dtype=dtype)

@pytest.mark.parametrize('k', [KTF], ids=['TensorFlow'])
def test_batch_dot_shape(self, k):
x_batch = k.ones(shape=(32, 20))
Expand Down Expand Up @@ -1397,6 +1392,14 @@ def test_set_floatx(self):
# Restore old value
set_floatx(old_floatx)

def test_variable_support_bool_dtype(self):
# Github issue: 7819
if K.backend() == 'tensorflow':
assert K.dtype(K.variable(1, dtype='int16')) == 'int16'
assert K.dtype(K.variable(False, dtype='bool')) == 'bool'
with pytest.raises(TypeError):
K.variable('', dtype='unsupported')


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit ba29c60

Please sign in to comment.