Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Breaking down the attention API PR: part 2 #11140

Merged
merged 1 commit into from
Sep 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions keras/engine/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,8 @@ def _base_init(self, name=None):
def _init_graph_network(self, inputs, outputs, name=None):
self._uses_inputs_arg = True
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
else:
self.inputs = [inputs]
if isinstance(outputs, (list, tuple)):
self.outputs = list(outputs)
else:
self.outputs = [outputs]
self.inputs = to_list(inputs, allow_tuple=True)
self.outputs = to_list(outputs, allow_tuple=True)

# User-provided argument validation.
# Check for redundancy in inputs.
Expand Down
18 changes: 4 additions & 14 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,7 @@ def _set_inputs(self, inputs, outputs=None, training=None):
self._feed_inputs = []
self._feed_input_names = []
self._feed_input_shapes = []
if isinstance(inputs, (list, tuple)):
inputs = list(inputs)
else:
inputs = [inputs]
inputs = to_list(inputs, allow_tuple=True)

for i, v in enumerate(inputs):
name = 'input_%d' % (i + 1)
Expand Down Expand Up @@ -633,10 +630,7 @@ def _set_inputs(self, inputs, outputs=None, training=None):
outputs = self.call(unpack_singleton(self.inputs), training=training)
else:
outputs = self.call(unpack_singleton(self.inputs))
if isinstance(outputs, (list, tuple)):
outputs = list(outputs)
else:
outputs = [outputs]
outputs = to_list(outputs, allow_tuple=True)
self.outputs = outputs
self.output_names = [
'output_%d' % (i + 1) for i in range(len(self.outputs))]
Expand Down Expand Up @@ -704,10 +698,7 @@ def _standardize_user_data(self, x,
'You passed: y=' + str(y))
# Typecheck that all inputs are *either* value *or* symbolic.
if y is not None:
if isinstance(y, (list, tuple)):
all_inputs += list(y)
else:
all_inputs.append(y)
all_inputs += to_list(y, allow_tuple=True)
if any(K.is_tensor(v) for v in all_inputs):
if not all(K.is_tensor(v) for v in all_inputs):
raise ValueError('Do not pass inputs that mix Numpy '
Expand All @@ -716,8 +707,7 @@ def _standardize_user_data(self, x,
'; y=' + str(y))

# Handle target tensors if any passed.
if not isinstance(y, (list, tuple)):
y = [y]
y = to_list(y, allow_tuple=True)
target_tensors = [v for v in y if K.is_tensor(v)]
if not target_tensors:
target_tensors = None
Expand Down
5 changes: 2 additions & 3 deletions keras/layers/advanced_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..engine.base_layer import InputSpec
from .. import backend as K
from ..legacy import interfaces
from ..utils.generic_utils import to_list


class LeakyReLU(Layer):
Expand Down Expand Up @@ -100,10 +101,8 @@ def __init__(self, alpha_initializer='zeros',
self.alpha_constraint = constraints.get(alpha_constraint)
if shared_axes is None:
self.shared_axes = None
elif not isinstance(shared_axes, (list, tuple)):
self.shared_axes = [shared_axes]
else:
self.shared_axes = list(shared_axes)
self.shared_axes = to_list(shared_axes, allow_tuple=True)

def build(self, input_shape):
param_shape = list(input_shape[1:])
Expand Down
9 changes: 3 additions & 6 deletions keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..legacy.layers import Recurrent, ConvRecurrent2D
from .recurrent import RNN
from ..utils.generic_utils import has_arg
from ..utils.generic_utils import to_list
from ..utils.generic_utils import transpose_shape


Expand Down Expand Up @@ -387,10 +388,7 @@ def step(inputs, states):
output._uses_learning_phase = True

if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
states = to_list(states, allow_tuple=True)
return [output] + states
else:
return output
Expand Down Expand Up @@ -443,8 +441,7 @@ def get_tuple_shape(nb_channels):
K.set_value(self.states[0],
np.zeros(get_tuple_shape(self.cell.state_size)))
else:
if not isinstance(states, (list, tuple)):
states = [states]
states = to_list(states, allow_tuple=True)
if len(states) != len(self.states):
raise ValueError('Layer ' + self.name + ' expects ' +
str(len(self.states)) + ' states, '
Expand Down
6 changes: 2 additions & 4 deletions keras/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .. import constraints
from ..engine.base_layer import Layer
from ..legacy import interfaces
from ..utils.generic_utils import to_list


class Embedding(Layer):
Expand Down Expand Up @@ -117,10 +118,7 @@ def compute_output_shape(self, input_shape):
return input_shape + (self.output_dim,)
else:
# input_length can be tuple if input is 3D or higher
if isinstance(self.input_length, (list, tuple)):
in_lens = list(self.input_length)
else:
in_lens = [self.input_length]
in_lens = to_list(self.input_length, allow_tuple=True)
if len(in_lens) != len(input_shape) - 1:
raise ValueError('"input_length" is %s, but received input has shape %s' %
(str(self.input_length), str(input_shape)))
Expand Down
9 changes: 3 additions & 6 deletions keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..engine.base_layer import Layer
from ..engine.base_layer import InputSpec
from ..utils.generic_utils import has_arg
from ..utils.generic_utils import to_list

# Legacy support.
from ..legacy.layers import Recurrent
Expand Down Expand Up @@ -664,10 +665,7 @@ def step(inputs, states):
state._uses_learning_phase = True

if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
states = to_list(states, allow_tuple=True)
return [output] + states
else:
return output
Expand Down Expand Up @@ -702,8 +700,7 @@ def reset_states(self, states=None):
K.set_value(self.states[0],
np.zeros((batch_size, self.cell.state_size)))
else:
if not isinstance(states, (list, tuple)):
states = [states]
states = to_list(states, allow_tuple=True)
if len(states) != len(self.states):
raise ValueError('Layer ' + self.name + ' expects ' +
str(len(self.states)) + ' states, '
Expand Down
11 changes: 3 additions & 8 deletions keras/legacy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,7 @@ def __call__(self, inputs, initial_state=None, **kwargs):
if initial_state is None:
return super(Recurrent, self).__call__(inputs, **kwargs)

if not isinstance(initial_state, (list, tuple)):
initial_state = [initial_state]
initial_state = to_list(initial_state, allow_tuple=True)

is_keras_tensor = hasattr(initial_state[0], '_keras_history')
for tensor in initial_state:
Expand Down Expand Up @@ -602,10 +601,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
output = last_output

if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
states = to_list(states, allow_tuple=True)
return [output] + states
else:
return output
Expand Down Expand Up @@ -633,8 +629,7 @@ def reset_states(self, states=None):
for state in self.states:
K.set_value(state, np.zeros((batch_size, self.units)))
else:
if not isinstance(states, (list, tuple)):
states = [states]
states = to_list(states, allow_tuple=True)
if len(states) != len(self.states):
raise ValueError('Layer ' + self.name + ' expects ' +
str(len(self.states)) + ' states, '
Expand Down
13 changes: 8 additions & 5 deletions keras/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,20 +444,26 @@ def add(self, n, values=None):
self.update(self._seen_so_far + n, values)


def to_list(x):
def to_list(x, allow_tuple=False):
    """Normalizes a list/tensor into a list.

    A list is returned unchanged (same object, no copy). Any other
    value is wrapped in a single-element list, except that a tuple
    is expanded element-wise into a list when `allow_tuple` is True.

    # Arguments
        x: target object to be normalized.
        allow_tuple: If False and x is a tuple,
            it will be converted into a list
            with a single element (the tuple).
            Else converts the tuple to a list.

    # Returns
        A list.
    """
    if isinstance(x, list):
        # Already normalized; return the same object to avoid copying.
        return x
    expand_tuple = allow_tuple and isinstance(x, tuple)
    return list(x) if expand_tuple else [x]


Expand All @@ -483,10 +489,7 @@ def object_list_uid(object_list):


def is_all_none(iterable_or_element):
if not isinstance(iterable_or_element, (list, tuple)):
iterable = [iterable_or_element]
else:
iterable = iterable_or_element
iterable = to_list(iterable_or_element, allow_tuple=True)
for element in iterable:
if element is not None:
return False
Expand Down