From b8004b5de8249160efe0badee4d68908b359abe8 Mon Sep 17 00:00:00 2001
From: Harini <32316484+heavani@users.noreply.github.com>
Date: Fri, 22 Mar 2019 14:14:16 -0700
Subject: [PATCH] Harini/tcn unit test (#351)

* updated weight norm for TCN model

* added sanity check unit test for tcn

* minor fix - checking for None in gradient

* Fixed style and pylint errors
---
 examples/np_semantic_segmentation/data.py     |   4 +
 .../adding_problem/adding_model.py            |  11 +-
 .../language_modeling_with_tcn.py             |   2 +-
 .../models/temporal_convolutional_network.py  | 220 ++++++++++++------
 tests/test_server_sanity.py                   |   3 +
 tests/test_tcn.py                             |  49 ++++
 6 files changed, 214 insertions(+), 75 deletions(-)
 create mode 100644 tests/test_tcn.py

diff --git a/examples/np_semantic_segmentation/data.py b/examples/np_semantic_segmentation/data.py
index 8ef81ff6..a0912aaf 100644
--- a/examples/np_semantic_segmentation/data.py
+++ b/examples/np_semantic_segmentation/data.py
@@ -105,6 +105,7 @@ def expand_np_candidates(np, stemming):
     candidates.extend(get_all_case_combinations(np))
     if stemming:
         # create all case-combinations of np's stem-> t-shirts to t-shirt etc.
+        # pylint: disable=no-member
         candidates.extend(get_all_case_combinations(fe.stem(np)))
     return candidates
 
@@ -154,9 +155,12 @@ def prepare_data(data_file, output_file, word2vec_file, http_proxy=None, https_p
     """
     # init_resources:
     global wordnet, wikidata, word2vec
+    # pylint: disable=no-member
     wordnet = fe.Wordnet()
+    # pylint: disable=no-member
     wikidata = fe.Wikidata(http_proxy, https_proxy)
     print("Start loading Word2Vec model (this might take a while...)")
+    # pylint: disable=no-member
     word2vec = fe.Word2Vec(word2vec_file)
     print("Finish loading feature extraction services")
     reader_list = read_csv_file_data(data_file)
diff --git a/examples/word_language_model_with_tcn/adding_problem/adding_model.py b/examples/word_language_model_with_tcn/adding_problem/adding_model.py
index 56f4d1b9..b98c900f 100644
--- a/examples/word_language_model_with_tcn/adding_problem/adding_model.py
+++ b/examples/word_language_model_with_tcn/adding_problem/adding_model.py
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from nlp_architect.models.temporal_convolutional_network import TCN
+from tqdm import tqdm
 
 
 class TCNForAdding(TCN):
@@ -45,7 +46,7 @@ def run(self, data_loader, num_iterations=1000, log_interval=100, result_dir="./
             result_dir: str, path to results directory
 
         Returns:
-            None
+            float, Training loss of last iteration
         """
         summary_writer = tf.summary.FileWriter(os.path.join(result_dir, "tfboard"),
                                                tf.get_default_graph())
@@ -55,7 +56,7 @@ def run(self, data_loader, num_iterations=1000, log_interval=100, result_dir="./
             init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
             sess.run(init)
 
-            for i in range(num_iterations):
+            for i in tqdm(range(num_iterations)):
 
                 x_data, y_data = next(data_loader)
 
@@ -64,7 +65,6 @@ def run(self, data_loader, num_iterations=1000, log_interval=100, result_dir="./
                 _, summary_train, total_loss_i = sess.run([self.training_update_step,
                                                            self.merged_summary_op_train,
                                                            self.training_loss], feed_dict=feed_dict)
-
                 summary_writer.add_summary(summary_train, i)
 
                 if i % log_interval == 0:
@@ -81,6 +81,8 @@ def run(self, data_loader, num_iterations=1000, log_interval=100, result_dir="./
 
                     print("Validation loss: {}".format(val_loss))
 
+        return total_loss_i
+
     # pylint: disable = arguments-differ
     def build_train_graph(self, lr, max_gradient_norm=None):
         """
@@ -114,7 +116,8 @@
         params = tf.trainable_variables()
         gradients = tf.gradients(self.training_loss, params)
         if max_gradient_norm is not None:
-            clipped_gradients = [tf.clip_by_norm(t, max_gradient_norm) for t in gradients]
+            clipped_gradients = [t if t is None else tf.clip_by_norm(t, max_gradient_norm)
+                                 for t in gradients]
         else:
             clipped_gradients = gradients
 
diff --git a/examples/word_language_model_with_tcn/mle_language_model/language_modeling_with_tcn.py b/examples/word_language_model_with_tcn/mle_language_model/language_modeling_with_tcn.py
index d52e4e7c..3a43f6b4 100644
--- a/examples/word_language_model_with_tcn/mle_language_model/language_modeling_with_tcn.py
+++ b/examples/word_language_model_with_tcn/mle_language_model/language_modeling_with_tcn.py
@@ -115,7 +115,7 @@ def main(args):
                     help='# of levels (default: 4)')
 PARSER.add_argument('--lr', type=float, default=4, action=check_size(0, 100),
                     help='initial learning rate (default: 4)')
-PARSER.add_argument('--nhid', type=int, default=600, action=check_size(0, 1000),
+PARSER.add_argument('--nhid', type=int, default=600, action=check_size(0, 10000),
                     help='number of hidden units per layer (default: 600)')
 PARSER.add_argument('--em_len', type=int, default=600, action=check_size(0, 10000),
                     help='Length of embedding (default: 600)')
diff --git a/nlp_architect/models/temporal_convolutional_network.py b/nlp_architect/models/temporal_convolutional_network.py
index 75101420..d7b5d79c 100644
--- a/nlp_architect/models/temporal_convolutional_network.py
+++ b/nlp_architect/models/temporal_convolutional_network.py
@@ -19,78 +19,158 @@
 import tensorflow as tf
 # pylint: disable=no-name-in-module
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.layers import base
-from tensorflow.python.layers import utils
-from tensorflow.python.ops import nn_ops
+from tensorflow.python.keras.layers import Wrapper
 from tensorflow.python.layers.convolutional import Conv1D
-from tensorflow.python.keras import layers as keras_layers
-
-
-class _ConvWeightNorm(keras_layers.Conv1D, base.Layer):
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.eager import context
+from tensorflow.python.ops import nn_impl
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.ops import array_ops
+from tensorflow.python.framework import ops
+
+
+# ***NOTE***: The WeightNorm Class is copied from this TF issue:
+# https://github.com/tensorflow/tensorflow/issues/14070
+# Once this becomes part of the official TF release, it will be removed
+class WeightNorm(Wrapper):
+    """ This wrapper reparameterizes a layer by decoupling the weight's
+    magnitude and direction. This speeds up convergence by improving the
+    conditioning of the optimization problem.
+
+    Weight Normalization: A Simple Reparameterization to Accelerate
+    Training of Deep Neural Networks: https://arxiv.org/abs/1602.07868
+    Tim Salimans, Diederik P. Kingma (2016)
+
+    WeightNorm wrapper works for keras and tf layers.
+
+    ```python
+      net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'),
+                       input_shape=(32, 32, 3), data_init=True)(x)
+      net = WeightNorm(tf.keras.layers.Conv2D(16, 5, activation='relu'),
+                       data_init=True)
+      net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'),
+                       data_init=True)(net)
+      net = WeightNorm(tf.keras.layers.Dense(n_classes),
+                       data_init=True)(net)
+    ```
+
+    Arguments:
+      layer: a layer instance.
+      data_init: If `True` use data dependent variable initialization
+
+    Raises:
+      ValueError: If not initialized with a `Layer` instance.
+      ValueError: If `Layer` does not contain a `kernel` of weights
+      NotImplementedError: If `data_init` is True and running graph execution
     """
-    Convolution base class that uses weight norm
-    """
-    def __init__(self, *args,
-                 **kwargs):
-        super(_ConvWeightNorm, self).__init__(*args,
-                                              **kwargs)
-        self.kernel_v = None
-        self.kernel_g = None
-        self.kernel = None
-        self.bias = None
-        self._convolution_op = None
-
+    def __init__(self, layer, data_init=False, **kwargs):
+        if not isinstance(layer, Layer):
+            raise ValueError(
+                'Please initialize `WeightNorm` layer with a '
+                '`Layer` instance. You passed: {input}'.format(input=layer))
+
+        if not context.executing_eagerly() and data_init:
+            raise NotImplementedError(
+                'Data dependent variable initialization is not available for '
+                'graph execution')
+
+        self.initialized = True
+        if data_init:
+            self.initialized = False
+
+        self.layer_depth = None
+        self.norm_axes = None
+        super(WeightNorm, self).__init__(layer, **kwargs)
+        self._track_checkpointable(layer, name='layer')
+
+    def _compute_weights(self):
+        """Generate weights by combining the direction of the weight vector
+           with its norm"""
+        with variable_scope.variable_scope('compute_weights'):
+            self.layer.kernel = nn_impl.l2_normalize(
+                self.layer.v, axis=self.norm_axes) * self.layer.g
+
+    def _init_norm(self, weights):
+        """Set the norm of the weight vector"""
+        from tensorflow.python.ops.linalg_ops import norm
+        with variable_scope.variable_scope('init_norm'):
+            # pylint: disable=no-member
+            flat = array_ops.reshape(weights, [-1, self.layer_depth])
+            # pylint: disable=no-member
+            return array_ops.reshape(norm(flat, axis=0), (self.layer_depth,))
+
+    def _data_dep_init(self, inputs):
+        """Data dependent initialization for eager execution"""
+        from tensorflow.python.ops.nn import moments
+        from tensorflow.python.ops.math_ops import sqrt
+
+        with variable_scope.variable_scope('data_dep_init'):
+            # Generate data dependent init values
+            activation = self.layer.activation
+            self.layer.activation = None
+            x_init = self.layer.call(inputs)
+            m_init, v_init = moments(x_init, self.norm_axes)
+            scale_init = 1. / sqrt(v_init + 1e-10)
+
+        # Assign data dependent init values
+        self.layer.g = self.layer.g * scale_init
+        self.layer.bias = (-1 * m_init * scale_init)
+        self.layer.activation = activation
+        self.initialized = True
+
+    # pylint: disable=signature-differs
     def build(self, input_shape):
-        input_shape = tensor_shape.TensorShape(input_shape)
-        if self.data_format == 'channels_first':
-            channel_axis = 1
-        else:
-            channel_axis = -1
-        # pylint: disable=no-member
-        if input_shape[channel_axis].value is None:
-            raise ValueError('The channel dimension of the inputs '
-                             'should be defined. Found `None`.')
-        # pylint: disable=no-member
-        input_dim = input_shape[channel_axis].value
-        kernel_shape = self.kernel_size + (input_dim, self.filters)
-
-        # The variables defined below are specific to the weight normed conv class
-        self.kernel_v = self.add_variable(name='kernel_v',
-                                          shape=kernel_shape,
-                                          initializer=self.kernel_initializer,
-                                          regularizer=self.kernel_regularizer,
-                                          constraint=self.kernel_constraint,
-                                          trainable=True,
-                                          dtype=self.dtype)
-        self.kernel_g = self.add_variable(name='kernel_g', shape=[], trainable=True,
-                                          dtype=self.dtype)
-        self.kernel = self.kernel_g * tf.nn.l2_normalize(self.kernel_v)
-
-        if self.use_bias:
-            self.bias = self.add_variable(name='bias',
-                                          shape=(self.filters,),
-                                          initializer=self.bias_initializer,
-                                          regularizer=self.bias_regularizer,
-                                          constraint=self.bias_constraint,
-                                          trainable=True,
-                                          dtype=self.dtype)
-        else:
-            self.bias = None
-        self.input_spec = base.InputSpec(ndim=self.rank + 2,
-                                         axes={channel_axis: input_dim})
-        self._convolution_op = nn_ops.Convolution(
-            input_shape,
-            filter_shape=self.kernel.get_shape(),
-            dilation_rate=self.dilation_rate,
-            strides=self.strides,
-            padding=self.padding.upper(),
-            data_format=utils.convert_data_format(self.data_format,
-                                                  self.rank + 2))
+        """Build `Layer`"""
+        input_shape = tensor_shape.TensorShape(input_shape).as_list()
+        self.input_spec = InputSpec(shape=input_shape)
+
+        if not self.layer.built:
+            self.layer.build(input_shape)
+            self.layer.built = False
+
+            if not hasattr(self.layer, 'kernel'):
+                raise ValueError(
+                    '`WeightNorm` must wrap a layer that'
+                    ' contains a `kernel` for weights'
+                )
+
+            # The kernel's filter or unit dimension is -1
+            self.layer_depth = int(self.layer.kernel.shape[-1])
+            self.norm_axes = list(range(self.layer.kernel.shape.ndims - 1))
+
+            self.layer.v = self.layer.kernel
+            self.layer.g = self.layer.add_variable(
+                name="g",
+                shape=(self.layer_depth,),
+                initializer=initializers.get('ones'),
+                dtype=self.layer.kernel.dtype,
+                trainable=True)
+
+            with ops.control_dependencies([self.layer.g.assign(
+                    self._init_norm(self.layer.v))]):
+                self._compute_weights()
+
+            self.layer.built = True
+
+        super(WeightNorm, self).build()
         self.built = True
 
+    # pylint: disable=arguments-differ
+    def call(self, inputs):
+        """Call `Layer`"""
+        if context.executing_eagerly():
+            if not self.initialized:
+                self._data_dep_init(inputs)
+            self._compute_weights()  # Recompute weights for each forward pass
+
+        output = self.layer.call(inputs)
+        return output
 
-# re-orient the Conv1D class to point to the weight norm version of conv base class
-Conv1D.__bases__ = (_ConvWeightNorm,)
+    def compute_output_shape(self, input_shape):
+        return tensor_shape.TensorShape(
+            self.layer.compute_output_shape(input_shape).as_list())
 
 
 class TCN:
@@ -253,10 +333,10 @@ def _dilated_causal_conv(self, x, n_filters, dilation, padding):
         with tf.variable_scope("dilated_causal_conv"):
             # define dilated convolution layer with left side padding
             x = tf.pad(x, tf.constant([[0, 0], [padding, 0], [0, 0]]), 'CONSTANT')
-            x = Conv1D(filters=n_filters, kernel_size=self.kernel_size, padding='valid', strides=1,
-                       activation=None, dilation_rate=dilation,
-                       kernel_initializer=tf.initializers.random_normal(0, 0.01),
-                       bias_initializer=tf.initializers.random_normal(0, 0.01))(x)
+            x = WeightNorm(Conv1D(filters=n_filters, kernel_size=self.kernel_size, padding='valid',
+                                  strides=1, activation=None, dilation_rate=dilation,
+                                  kernel_initializer=tf.initializers.random_normal(0, 0.01),
+                                  bias_initializer=tf.initializers.random_normal(0, 0.01)))(x)
 
         assert x.shape[1].value == input_width
 
diff --git a/tests/test_server_sanity.py b/tests/test_server_sanity.py
index 4695a3c0..1ce909a3 100644
--- a/tests/test_server_sanity.py
+++ b/tests/test_server_sanity.py
@@ -122,6 +122,7 @@ def test_request(service_name):
     myHeaders = headers.copy()
     myHeaders["content-type"] = "application/json"
     myHeaders["Response-Format"] = "json"
+    # pylint: disable=no-member
     response = hug.test.post(api, '/inference', body=doc, headers=myHeaders)
     assert_response_struct(response.data, json.loads(expected_result))
 
@@ -139,6 +140,7 @@ def test_gzip_file_request(service_name):
     myHeaders["content-type"] = "application/gzip"
     myHeaders["Response-Format"] = "gzip"
     myHeaders["content-encoding"] = "gzip"
+    # pylint: disable=no-member
     response = hug.test.post(api, '/inference', body=doc, headers=myHeaders)
     result_doc = get_decompressed_gzip(response.data)
     assert_response_struct(result_doc, json.loads(expected_result))
@@ -155,6 +157,7 @@ def test_json_file_request(service_name):
     myHeaders = headers.copy()
     myHeaders["Content-Type"] = "application/json"
     myHeaders["RESPONSE-FORMAT"] = "json"
+    # pylint: disable=no-member
     response = hug.test.post(nlp_architect.server.serve, '/inference', body=doc, headers=myHeaders)
     assert_response_struct(response.data, json.loads(expected_result))
     assert response.status == hug.HTTP_OK
diff --git a/tests/test_tcn.py b/tests/test_tcn.py
new file mode 100644
index 00000000..7daef708
--- /dev/null
+++ b/tests/test_tcn.py
@@ -0,0 +1,49 @@
+# ******************************************************************************
+# Copyright 2017-2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+from examples.word_language_model_with_tcn.adding_problem.adding_model import TCNForAdding
+from examples.word_language_model_with_tcn.toy_data.adding import Adding
+
+
+def test_tcn_adding():
+    """
+    Sanity check -
+    Test function checks to make sure training loss drops to ~0 on small dummy dataset
+    """
+    n_features = 2
+    hidden_sizes = [64] * 3
+    kernel_size = 3
+    dropout = 0.0
+    seq_len = 10
+    n_train = 5000
+    n_val = 100
+    batch_size = 32
+    n_epochs = 10
+    num_iterations = int(n_train * n_epochs * 1.0 / batch_size)
+    lr = 0.002
+    grad_clip_value = 10
+    results_dir = "./"
+
+    adding_dataset = Adding(seq_len=seq_len, n_train=n_train, n_test=n_val)
+
+    model = TCNForAdding(seq_len, n_features, hidden_sizes, kernel_size=kernel_size,
+                         dropout=dropout)
+
+    model.build_train_graph(lr, max_gradient_norm=grad_clip_value)
+
+    training_loss = model.run(adding_dataset, num_iterations=num_iterations, log_interval=1e6,
+                              result_dir=results_dir)
+
+    assert training_loss < 1e-3
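
Reviewer note on the `build_train_graph` change above: `tf.gradients` returns `None` for any trainable variable that does not contribute to the loss, and passing `None` to `tf.clip_by_norm` raises an error, which is why the new list comprehension forwards `None` entries unchanged. A minimal TF 1.x sketch of the pattern; the variable names and the clip norm of 10.0 are illustrative, not taken from this patch:

```python
import tensorflow as tf

# One variable feeds the loss, one does not; tf.gradients yields None
# for the unused one rather than a zero tensor.
used = tf.Variable(1.0, name="used")
unused = tf.Variable(1.0, name="unused")
loss = tf.square(used)

params = tf.trainable_variables()
grads = tf.gradients(loss, params)  # [<Tensor>, None]

# Clip only the real gradients; pass None placeholders through untouched,
# mirroring the guard added in build_train_graph.
clipped = [g if g is None else tf.clip_by_norm(g, 10.0) for g in grads]
```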
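
Similarly, a quick graph-mode usage sketch for the new `WeightNorm` wrapper, which replaces the old `Conv1D.__bases__` monkey-patch. This assumes the patch is applied so that `WeightNorm` is importable from `nlp_architect.models.temporal_convolutional_network`; the shapes and layer hyperparameters are illustrative:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.layers.convolutional import Conv1D

from nlp_architect.models.temporal_convolutional_network import WeightNorm

# Conv1D whose kernel is reparameterized as g * v / ||v|| by the wrapper.
inputs = tf.placeholder(tf.float32, [None, 10, 2])  # (batch, time, channels)
outputs = WeightNorm(Conv1D(filters=8, kernel_size=3, padding='valid',
                            dilation_rate=1))(inputs)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    y = sess.run(outputs, {inputs: np.random.randn(4, 10, 2).astype(np.float32)})
    print(y.shape)  # (4, 8, 8): 'valid' padding trims kernel_size - 1 time steps
```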