|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": null, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, parse_record_fn, num_epochs):\n", |
| 10 | + " # make the dataset prefetchable for parallellism\n", |
| 11 | + " dataset = dataset.prefetch(buffer_size=batch_size)\n", |
| 12 | + " \n", |
| 13 | + " # shuffle dataset\n", |
| 14 | + " if is_training:\n", |
| 15 | + " dataset = dataset.shuffle(buffer_size=shuffle_buffer)\n", |
| 16 | + " \n", |
| 17 | + " # repeat shuffled dataset for multi-epoch training\n", |
| 18 | + " dataset = dataset.repeat(num_epochs)\n", |
| 19 | + "\n", |
| 20 | + " # Parse the raw records into images and labels and batch them\n", |
| 21 | + " dataset = dataset.map(lambda x : parse_record_fn(x, is_training), num_parallel_calls=1) \n", |
| 22 | + " dataset = dataset.batch(batch_size)\n", |
| 23 | + " \n", |
| 24 | + " # prefetch one batch at a time\n", |
| 25 | + " dataset.prefetch(1)\n", |
| 26 | + "\n", |
| 27 | + " return dataset" |
| 28 | + ] |
| 29 | + }, |
| 30 | + { |
| 31 | + "cell_type": "code", |
| 32 | + "execution_count": null, |
| 33 | + "metadata": {}, |
| 34 | + "outputs": [], |
| 35 | + "source": [ |
| 36 | + "def learning_schedule(batch_size, batch_denom, num_images, boundary_epochs, decay_rates):\n", |
| 37 | + " initial_learning_rate = 0.1 * batch_size / batch_denom\n", |
| 38 | + " batches_per_epoch = num_images / batch_size\n", |
| 39 | + "\n", |
| 40 | + " # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.\n", |
| 41 | + " boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]\n", |
| 42 | + " vals = [initial_learning_rate * decay for decay in decay_rates]\n", |
| 43 | + "\n", |
| 44 | + " # a global step means running an optimization op on a batch\n", |
| 45 | + " def learning_rate_fn(global_step):\n", |
| 46 | + " global_step = tf.cast(global_step, tf.int32)\n", |
| 47 | + " return tf.train.piecewise_constant(global_step, boundaries, vals)\n", |
| 48 | + "\n", |
| 49 | + " return learning_rate_fn" |
| 50 | + ] |
| 51 | + }, |
| 52 | + { |
| 53 | + "cell_type": "code", |
| 54 | + "execution_count": null, |
| 55 | + "metadata": {}, |
| 56 | + "outputs": [], |
| 57 | + "source": [ |
| 58 | + "def resnet_model_fn(features, labels, mode, model_class,\n", |
| 59 | + " resnet_size, weight_decay, learning_rate_fn, momentum,\n", |
| 60 | + " data_format, resnet_version, loss_scale,\n", |
| 61 | + " loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE):\n", |
| 62 | + " \"\"\"Shared functionality for different resnet model_fns.\n", |
| 63 | + " Initializes the ResnetModel representing the model layers\n", |
| 64 | + " and uses that model to build the necessary EstimatorSpecs for\n", |
| 65 | + " the `mode` in question. For training, this means building losses,\n", |
| 66 | + " the optimizer, and the train op that get passed into the EstimatorSpec.\n", |
| 67 | + " For evaluation and prediction, the EstimatorSpec is returned without\n", |
| 68 | + " a train op, but with the necessary parameters for the given mode.\n", |
| 69 | + " Args:\n", |
| 70 | + " features: tensor representing input images\n", |
| 71 | + " labels: tensor representing class labels for all input images\n", |
| 72 | + " mode: current estimator mode; should be one of\n", |
| 73 | + " `tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`\n", |
| 74 | + " model_class: a class representing a TensorFlow model that has a __call__\n", |
| 75 | + " function. We assume here that this is a subclass of ResnetModel.\n", |
| 76 | + " resnet_size: A single integer for the size of the ResNet model.\n", |
| 77 | + " weight_decay: weight decay loss rate used to regularize learned variables.\n", |
| 78 | + " learning_rate_fn: function that returns the current learning rate given\n", |
| 79 | + " the current global_step\n", |
| 80 | + " momentum: momentum term used for optimization\n", |
| 81 | + " data_format: Input format ('channels_last', 'channels_first', or None).\n", |
| 82 | + " If set to None, the format is dependent on whether a GPU is available.\n", |
| 83 | + " resnet_version: Integer representing which version of the ResNet network to\n", |
| 84 | + " use. See README for details. Valid values: [1, 2]\n", |
| 85 | + " loss_scale: The factor to scale the loss for numerical stability. A detailed\n", |
| 86 | + " summary is present in the arg parser help text.\n", |
| 87 | + " loss_filter_fn: function that takes a string variable name and returns\n", |
| 88 | + " True if the var should be included in loss calculation, and False\n", |
| 89 | + " otherwise. If None, batch_normalization variables will be excluded\n", |
| 90 | + " from the loss.\n", |
| 91 | + " dtype: the TensorFlow dtype to use for calculations.\n", |
| 92 | + " Returns:\n", |
| 93 | + " EstimatorSpec parameterized according to the input params and the\n", |
| 94 | + " current mode.\n", |
| 95 | + " \"\"\"\n", |
| 96 | + "\n", |
| 97 | + " # Generate a summary node for the images\n", |
| 98 | + " tf.summary.image('images', features, max_outputs=6)\n", |
| 99 | + "\n", |
| 100 | + " features = tf.cast(features, dtype)\n", |
| 101 | + "\n", |
| 102 | + " model = model_class(resnet_size, data_format, resnet_version=resnet_version, dtype=dtype)\n", |
| 103 | + "\n", |
| 104 | + " logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)\n", |
| 105 | + "\n", |
| 106 | + " # This acts as a no-op if the logits are already in fp32 (provided logits are\n", |
| 107 | + " # not a SparseTensor). If dtype is is low precision, logits must be cast to\n", |
| 108 | + " # fp32 for numerical stability.\n", |
| 109 | + " logits = tf.cast(logits, tf.float32)\n", |
| 110 | + "\n", |
| 111 | + " predictions = {\n", |
| 112 | + " 'classes': tf.argmax(logits, axis=1),\n", |
| 113 | + " 'probabilities': tf.nn.softmax(logits, name='softmax_tensor')\n", |
| 114 | + " }\n", |
| 115 | + "\n", |
| 116 | + " if mode == tf.estimator.ModeKeys.PREDICT:\n", |
| 117 | + " # Return the predictions and the specification for serving a SavedModel\n", |
| 118 | + " return tf.estimator.EstimatorSpec(\n", |
| 119 | + " mode=mode,\n", |
| 120 | + " predictions=predictions,\n", |
| 121 | + " export_outputs={\n", |
| 122 | + " 'predict': tf.estimator.export.PredictOutput(predictions)\n", |
| 123 | + " })\n", |
| 124 | + "\n", |
| 125 | + " # Calculate loss, which includes softmax cross entropy and L2 regularization.\n", |
| 126 | + " # cross entropy part\n", |
| 127 | + " cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)\n", |
| 128 | + "\n", |
| 129 | + " # Create a tensor named cross_entropy for logging purposes.\n", |
| 130 | + " tf.identity(cross_entropy, name='cross_entropy')\n", |
| 131 | + " tf.summary.scalar('cross_entropy', cross_entropy)\n", |
| 132 | + " \n", |
| 133 | + " # L2 regularization part\n", |
| 134 | + " def exclude_batch_norm(name):\n", |
| 135 | + " return 'batch_normalization' not in name\n", |
| 136 | + " \n", |
| 137 | + " loss_filter_fn = loss_filter_fn or exclude_batch_norm\n", |
| 138 | + "\n", |
| 139 | + " # Add weight decay to the loss.\n", |
| 140 | + " l2_loss = weight_decay * tf.add_n(\n", |
| 141 | + " [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()\n", |
| 142 | + " if loss_filter_fn(v.name)])\n", |
| 143 | + " \n", |
| 144 | + " tf.summary.scalar('l2_loss', l2_loss)\n", |
| 145 | + " loss = cross_entropy + l2_loss\n", |
| 146 | + "\n", |
| 147 | + " if mode == tf.estimator.ModeKeys.TRAIN:\n", |
| 148 | + " global_step = tf.train.get_or_create_global_step()\n", |
| 149 | + "\n", |
| 150 | + " learning_rate = learning_rate_fn(global_step)\n", |
| 151 | + "\n", |
| 152 | + " # Create a tensor named learning_rate for logging purposes\n", |
| 153 | + " tf.identity(learning_rate, name='learning_rate')\n", |
| 154 | + " tf.summary.scalar('learning_rate', learning_rate)\n", |
| 155 | + "\n", |
| 156 | + " optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum)\n", |
| 157 | + "\n", |
| 158 | + " if loss_scale != 1:\n", |
| 159 | + " # When computing fp16 gradients, often intermediate tensor values are\n", |
| 160 | + " # so small, they underflow to 0. To avoid this, we multiply the loss by\n", |
| 161 | + " # loss_scale to make these tensor values loss_scale times bigger.\n", |
| 162 | + " scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)\n", |
| 163 | + "\n", |
| 164 | + " # Once the gradient computation is complete we can scale the gradients\n", |
| 165 | + " # back to the correct scale before passing them to the optimizer.\n", |
| 166 | + " unscaled_grad_vars = [(grad / loss_scale, var) for grad, var in scaled_grad_vars]\n", |
| 167 | + " minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)\n", |
| 168 | + " else:\n", |
| 169 | + " minimize_op = optimizer.minimize(loss, global_step)\n", |
| 170 | + " \n", |
| 171 | + " update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n", |
| 172 | + " train_op = tf.group(minimize_op, update_ops)\n", |
| 173 | + " else:\n", |
| 174 | + " train_op = None\n", |
| 175 | + "\n", |
| 176 | + " \n", |
| 177 | + " if not tf.contrib.distribute.has_distribution_strategy():\n", |
| 178 | + " accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1), predictions['classes'])\n", |
| 179 | + " else:\n", |
| 180 | + " # Metrics are currently not compatible with distribution strategies during\n", |
| 181 | + " # training. This does not affect the overall performance of the model.\n", |
| 182 | + " accuracy = (tf.no_op(), tf.constant(0))\n", |
| 183 | + "\n", |
| 184 | + " metrics = {'accuracy': accuracy}\n", |
| 185 | + "\n", |
| 186 | + " # Create a tensor named train_accuracy for logging purposes\n", |
| 187 | + " tf.identity(accuracy[1], name='train_accuracy')\n", |
| 188 | + " tf.summary.scalar('train_accuracy', accuracy[1])\n", |
| 189 | + "\n", |
| 190 | + " return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, loss=loss, \n", |
| 191 | + " train_op=train_op, eval_metric_ops=metrics)" |
| 192 | + ] |
| 193 | + }, |
| 194 | + { |
| 195 | + "cell_type": "code", |
| 196 | + "execution_count": null, |
| 197 | + "metadata": {}, |
| 198 | + "outputs": [], |
| 199 | + "source": [ |
| 200 | + "def resnet_main(flags_obj, model_function, input_function, dataset_name, shape=None):\n", |
| 201 | + " \"\"\"Shared main loop for ResNet Models.\n", |
| 202 | + " Args:\n", |
| 203 | + " flags_obj: An object containing parsed flags. See define_resnet_flags()\n", |
| 204 | + " for details.\n", |
| 205 | + " model_function: the function that instantiates the Model and builds the\n", |
| 206 | + " ops for train/eval. This will be passed directly into the estimator.\n", |
| 207 | + " input_function: the function that processes the dataset and returns a\n", |
| 208 | + " dataset that the estimator can train on. This will be wrapped with\n", |
| 209 | + " all the relevant flags for running and passed to estimator.\n", |
| 210 | + " dataset_name: the name of the dataset for training and evaluation. This is\n", |
| 211 | + " used for logging purpose.\n", |
| 212 | + " shape: list of ints representing the shape of the images used for training.\n", |
| 213 | + " This is only used if flags_obj.export_dir is passed.\n", |
| 214 | + " \"\"\"\n", |
| 215 | + "\n", |
| 216 | + " # Using the Winograd non-fused algorithms provides a small performance boost.\n", |
| 217 | + " os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'\n", |
| 218 | + "\n", |
| 219 | + " # Create session config based on values of inter_op_parallelism_threads and\n", |
| 220 | + " # intra_op_parallelism_threads. Note that we default to having\n", |
| 221 | + " # allow_soft_placement = True, which is required for multi-GPU and not\n", |
| 222 | + " # harmful for other modes.\n", |
| 223 | + " session_config = tf.ConfigProto(\n", |
| 224 | + " inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,\n", |
| 225 | + " intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,\n", |
| 226 | + " allow_soft_placement=True)\n", |
| 227 | + "\n", |
| 228 | + " if flags_core.get_num_gpus(flags_obj) == 0:\n", |
| 229 | + " distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')\n", |
| 230 | + " elif flags_core.get_num_gpus(flags_obj) == 1:\n", |
| 231 | + " distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')\n", |
| 232 | + " else:\n", |
| 233 | + " distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=flags_core.get_num_gpus(flags_obj))\n", |
| 234 | + "\n", |
| 235 | + " run_config = tf.estimator.RunConfig(train_distribute=distribution, session_config=session_config)\n", |
| 236 | + "\n", |
| 237 | + " classifier = tf.estimator.Estimator(model_fn=model_function, model_dir=flags_obj.model_dir, \n", |
| 238 | + " config=run_config,\n", |
| 239 | + " params={\n", |
| 240 | + " 'resnet_size': int(flags_obj.resnet_size),\n", |
| 241 | + " 'data_format': flags_obj.data_format,\n", |
| 242 | + " 'batch_size': flags_obj.batch_size,\n", |
| 243 | + " 'resnet_version': int(flags_obj.resnet_version),\n", |
| 244 | + " 'loss_scale': flags_core.get_loss_scale(flags_obj),\n", |
| 245 | + " 'dtype': flags_core.get_tf_dtype(flags_obj)\n", |
| 246 | + " })\n", |
| 247 | + "\n", |
| 248 | + " run_params = {\n", |
| 249 | + " 'batch_size': flags_obj.batch_size,\n", |
| 250 | + " 'dtype': flags_core.get_tf_dtype(flags_obj),\n", |
| 251 | + " 'resnet_size': flags_obj.resnet_size,\n", |
| 252 | + " 'resnet_version': flags_obj.resnet_version,\n", |
| 253 | + " 'synthetic_data': flags_obj.use_synthetic_data,\n", |
| 254 | + " 'train_epochs': flags_obj.train_epochs,\n", |
| 255 | + " }\n", |
| 256 | + " \n", |
| 257 | + " benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)\n", |
| 258 | + " benchmark_logger.log_run_info('resnet', dataset_name, run_params)\n", |
| 259 | + "\n", |
| 260 | + " train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,batch_size=flags_obj.batch_size,\n", |
| 261 | + " benchmark_log_dir=flags_obj.benchmark_log_dir)\n", |
| 262 | + "\n", |
| 263 | + " def input_fn_train():\n", |
| 264 | + " return input_function(is_training=True, data_dir=flags_obj.data_dir,\n", |
| 265 | + " batch_size=per_device_batch_size(flags_obj.batch_size, \n", |
| 266 | + " flags_core.get_num_gpus(flags_obj)),\n", |
| 267 | + " num_epochs=flags_obj.epochs_between_evals)\n", |
| 268 | + "\n", |
| 269 | + " def input_fn_eval():\n", |
| 270 | + " return input_function(is_training=False, data_dir=flags_obj.data_dir,\n", |
| 271 | + " batch_size=per_device_batch_size(flags_obj.batch_size, \n", |
| 272 | + " flags_core.get_num_gpus(flags_obj)),\n", |
| 273 | + " num_epochs=1)\n", |
| 274 | + "\n", |
| 275 | + " total_training_cycle = (flags_obj.train_epochs // flags_obj.epochs_between_evals)\n", |
| 276 | + " \n", |
| 277 | + " for cycle_index in range(total_training_cycle):\n", |
| 278 | + " tf.logging.info('Starting a training cycle: %d/%d', cycle_index, total_training_cycle)\n", |
| 279 | + "\n", |
| 280 | + " classifier.train(input_fn=input_fn_train, hooks=train_hooks,\n", |
| 281 | + " max_steps=flags_obj.max_train_steps)\n", |
| 282 | + "\n", |
| 283 | + " tf.logging.info('Starting to evaluate.')\n", |
| 284 | + "\n", |
| 285 | + " # flags_obj.max_train_steps is generally associated with testing and\n", |
| 286 | + " # profiling. As a result it is frequently called with synthetic data, which\n", |
| 287 | + " # will iterate forever. Passing steps=flags_obj.max_train_steps allows the\n", |
| 288 | + " # eval (which is generally unimportant in those circumstances) to terminate.\n", |
| 289 | + " # Note that eval will run for max_train_steps each loop, regardless of the\n", |
| 290 | + " # global_step count.\n", |
| 291 | + " eval_results = classifier.evaluate(input_fn=input_fn_eval,\n", |
| 292 | + " steps=flags_obj.max_train_steps)\n", |
| 293 | + "\n", |
| 294 | + " benchmark_logger.log_evaluation_result(eval_results)\n", |
| 295 | + "\n", |
| 296 | + " if model_helpers.past_stop_threshold(flags_obj.stop_threshold, eval_results['accuracy']):\n", |
| 297 | + " break\n", |
| 298 | + "\n", |
| 299 | + " if flags_obj.export_dir is not None:\n", |
| 300 | + " # Exports a saved model for the given classifier.\n", |
| 301 | + " input_receiver_fn = export.build_tensor_serving_input_receiver_fn(shape, batch_size=flags_obj.batch_size)\n", |
| 302 | + " classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)\n", |
| 303 | + "\n", |
| 304 | + "\n", |
| 305 | + "def define_resnet_flags(resnet_size_choices=None):\n", |
| 306 | + " \"\"\"Add flags and validators for ResNet.\"\"\"\n", |
| 307 | + " flags_core.define_base()\n", |
| 308 | + " flags_core.define_performance(num_parallel_calls=False)\n", |
| 309 | + " flags_core.define_image()\n", |
| 310 | + " flags_core.define_benchmark()\n", |
| 311 | + " flags.adopt_module_key_flags(flags_core)\n", |
| 312 | + "\n", |
| 313 | + " flags.DEFINE_enum(\n", |
| 314 | + " name='resnet_version', short_name='rv', default='2',\n", |
| 315 | + " enum_values=['1', '2'],\n", |
| 316 | + " help=flags_core.help_wrap(\n", |
| 317 | + " 'Version of ResNet. (1 or 2) See README.md for details.'))\n", |
| 318 | + "\n", |
| 319 | + "\n", |
| 320 | + " choice_kwargs = dict(\n", |
| 321 | + " name='resnet_size', short_name='rs', default='50',\n", |
| 322 | + " help=flags_core.help_wrap('The size of the ResNet model to use.'))\n", |
| 323 | + "\n", |
| 324 | + " if resnet_size_choices is None:\n", |
| 325 | + " flags.DEFINE_string(**choice_kwargs)\n", |
| 326 | + " else:\n", |
| 327 | + " flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)" |
| 328 | + ] |
| 329 | + } |
| 330 | + ], |
| 331 | + "metadata": { |
| 332 | + "kernelspec": { |
| 333 | + "display_name": "Python 3", |
| 334 | + "language": "python", |
| 335 | + "name": "python3" |
| 336 | + }, |
| 337 | + "language_info": { |
| 338 | + "codemirror_mode": { |
| 339 | + "name": "ipython", |
| 340 | + "version": 3 |
| 341 | + }, |
| 342 | + "file_extension": ".py", |
| 343 | + "mimetype": "text/x-python", |
| 344 | + "name": "python", |
| 345 | + "nbconvert_exporter": "python", |
| 346 | + "pygments_lexer": "ipython3", |
| 347 | + "version": "3.6.5" |
| 348 | + } |
| 349 | + }, |
| 350 | + "nbformat": 4, |
| 351 | + "nbformat_minor": 2 |
| 352 | +} |
0 commit comments