
Commit 383d414

minor updates
1 parent 2fc637b commit 383d414

6 files changed (+25 -23 lines changed)

model.py (+4 -3)
@@ -26,6 +26,7 @@
 import objective as obj_lib
 
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator
 import tensorflow.compat.v2 as tf2
 
 FLAGS = flags.FLAGS
@@ -35,7 +36,7 @@ def build_model_fn(model, num_classes, num_train_examples):
   """Build model function."""
   def model_fn(features, labels, mode, params=None):
     """Build model and optimizer."""
-    is_training = mode == tf.estimator.ModeKeys.TRAIN
+    is_training = mode == tf_estimator.ModeKeys.TRAIN
 
     # Check training mode.
     if FLAGS.train_mode == 'pretrain':
@@ -183,7 +184,7 @@ def scaffold_fn():
       else:
         scaffold_fn = None
 
-      return tf.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
           mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
     else:
 
@@ -215,7 +216,7 @@ def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
               tf.losses.get_regularization_loss()),
       }
 
-      return tf.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, metrics),
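
The model.py hunks (and the matching run.py hunks below) replace every tf.estimator reference with a tf_estimator alias imported explicitly from tensorflow.compat.v1, presumably to keep the TF1 estimator API reachable on newer TensorFlow builds where the tf.estimator attribute path is deprecated. A minimal sketch of the aliasing pattern, not part of the commit:

    import tensorflow.compat.v1 as tf
    from tensorflow.compat.v1 import estimator as tf_estimator

    # Call sites keep their behaviour; only the module prefix changes.
    print(tf_estimator.ModeKeys.TRAIN)  # -> 'train'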

resnet.py (+2 -2)
@@ -64,10 +64,10 @@ def _cross_replica_average(self, t):
     num_shards = tpu_function.get_tpu_context().number_of_shards
     return tf.tpu.cross_replica_sum(t) / tf.cast(num_shards, t.dtype)
 
-  def _moments(self, inputs, reduction_axes, keep_dims):
+  def _moments(self, inputs, reduction_axes, keep_dims, mask=None):
     """Compute the mean and variance: it overrides the original _moments."""
     shard_mean, shard_variance = super(BatchNormalization, self)._moments(
-        inputs, reduction_axes, keep_dims=keep_dims)
+        inputs, reduction_axes, keep_dims=keep_dims, mask=mask)
 
     num_shards = tpu_function.get_tpu_context().number_of_shards
     if num_shards and num_shards > 1:
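
The resnet.py change widens the cross-replica _moments override to accept the mask argument that newer Keras BatchNormalization implementations apparently pass through, and simply forwards it to the parent. A toy sketch of that signature-forwarding pattern (standalone classes, not the repository's code):

    class Parent:
      def _moments(self, inputs, axes, keep_dims, mask=None):
        return 0.0, 1.0  # stand-in mean and variance

    class Child(Parent):
      # Accept every parent argument and pass it through, so callers that
      # supply mask= do not break on the subclass.
      def _moments(self, inputs, axes, keep_dims, mask=None):
        return super()._moments(inputs, axes, keep_dims=keep_dims, mask=mask)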

run.py (+6 -5)
@@ -31,6 +31,7 @@
 import model_util as model_util
 
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator
 import tensorflow_datasets as tfds
 import tensorflow_hub as hub
 
@@ -397,10 +398,10 @@ def main(argv):
     tf.config.experimental_connect_to_cluster(cluster)
     tf.tpu.experimental.initialize_tpu_system(cluster)
 
-  default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
-  sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
-  run_config = tf.estimator.tpu.RunConfig(
-      tpu_config=tf.estimator.tpu.TPUConfig(
+  default_eval_mode = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V1
+  sliced_eval_mode = tf_estimator.tpu.InputPipelineConfig.SLICED
+  run_config = tf_estimator.tpu.RunConfig(
+      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=checkpoint_steps,
          eval_training_input_configuration=sliced_eval_mode
          if FLAGS.use_tpu else default_eval_mode),
@@ -410,7 +411,7 @@ def main(argv):
       keep_checkpoint_max=FLAGS.keep_checkpoint_max,
       master=FLAGS.master,
       cluster=cluster)
-  estimator = tf.estimator.tpu.TPUEstimator(
+  estimator = tf_estimator.tpu.TPUEstimator(
      model_lib.build_model_fn(model, num_classes, num_train_examples),
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,

tf2/data.py (+2 -1)
@@ -105,10 +105,11 @@ def get_preprocess_fn(is_training, is_pretrain):
     test_crop = False
   else:
     test_crop = True
+  color_jitter_strength = FLAGS.color_jitter_strength if is_pretrain else 0.
   return functools.partial(
       data_util.preprocess_image,
       height=FLAGS.image_size,
       width=FLAGS.image_size,
       is_training=is_training,
-      color_distort=is_pretrain,
+      color_jitter_strength=color_jitter_strength,
       test_crop=test_crop)
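
With this hunk the jitter strength is resolved once in get_preprocess_fn (from FLAGS.color_jitter_strength for pretraining, 0. otherwise) and handed to data_util.preprocess_image as an argument instead of being read from flags inside data_util. A hedged usage sketch of the returned partial, with illustrative values only:

    # Pretraining pipeline: strength comes from --color_jitter_strength.
    preprocess_pretrain = get_preprocess_fn(is_training=True, is_pretrain=True)
    # Fine-tuning / eval pipeline: strength forced to 0., no color distortion.
    preprocess_finetune = get_preprocess_fn(is_training=True, is_pretrain=False)

    augmented = preprocess_pretrain(image)  # image: an [H, W, C] image Tensor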

tf2/data_util.py (+10 -11)
@@ -16,12 +16,9 @@
 """Data preprocessing and augmentation."""
 
 import functools
-from absl import flags
 
 import tensorflow.compat.v2 as tf
 
-FLAGS = flags.FLAGS
-
 CROP_PROPORTION = 0.875  # Standard for ImageNet.
 
 
@@ -446,7 +443,7 @@ def generate_selector(p, bsz):
 def preprocess_for_train(image,
                          height,
                          width,
-                         color_distort=True,
+                         color_jitter_strength=0.,
                          crop=True,
                          flip=True,
                          impl='simclrv2'):
@@ -456,11 +453,12 @@ def preprocess_for_train(image,
     image: `Tensor` representing an image of arbitrary size.
     height: Height of output image.
     width: Width of output image.
-    color_distort: Whether to apply the color distortion.
+    color_jitter_strength: `float` between 0 and 1 indicating the color
+      distortion strength, disable color distortion if not bigger than 0.
     crop: Whether to crop the image.
     flip: Whether or not to flip left and right of an image.
     impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
-        version of random brightness.
+      version of random brightness.
 
   Returns:
     A preprocessed image `Tensor`.
@@ -469,8 +467,8 @@ def preprocess_for_train(image,
     image = random_crop_with_resize(image, height, width)
   if flip:
     image = tf.image.random_flip_left_right(image)
-  if color_distort:
-    image = random_color_jitter(image, strength=FLAGS.color_jitter_strength,
+  if color_jitter_strength > 0:
+    image = random_color_jitter(image, strength=color_jitter_strength,
                                 impl=impl)
   image = tf.reshape(image, [height, width, 3])
   image = tf.clip_by_value(image, 0., 1.)
@@ -497,15 +495,16 @@ def preprocess_for_eval(image, height, width, crop=True):
 
 
 def preprocess_image(image, height, width, is_training=False,
-                     color_distort=True, test_crop=True):
+                     color_jitter_strength=0., test_crop=True):
   """Preprocesses the given image.
 
   Args:
     image: `Tensor` representing an image of arbitrary size.
     height: Height of output image.
     width: Width of output image.
     is_training: `bool` for whether the preprocessing is for training.
-    color_distort: whether to apply the color distortion.
+    color_jitter_strength: `float` between 0 and 1 indicating the color
+      distortion strength, disable color distortion if not bigger than 0.
     test_crop: whether or not to extract a central crop of the images
       (as for standard ImageNet evaluation) during the evaluation.
 
@@ -514,6 +513,6 @@ def preprocess_image(image, height, width, is_training=False,
   """
   image = tf.image.convert_image_dtype(image, dtype=tf.float32)
   if is_training:
-    return preprocess_for_train(image, height, width, color_distort)
+    return preprocess_for_train(image, height, width, color_jitter_strength)
   else:
     return preprocess_for_eval(image, height, width, test_crop)
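
After this refactor tf2/data_util.py no longer imports absl flags; callers pass color_jitter_strength explicitly. A minimal sketch of calling the updated preprocess_image directly (sizes and strength are illustrative, not from the commit):

    import tensorflow.compat.v2 as tf
    import data_util  # tf2/data_util.py

    image = tf.random.uniform([300, 400, 3])  # an arbitrary-size input image

    # Training view: color distortion enabled because strength > 0.
    train_view = data_util.preprocess_image(
        image, height=224, width=224, is_training=True,
        color_jitter_strength=0.5)

    # Evaluation view: central crop, no color distortion.
    eval_view = data_util.preprocess_image(
        image, height=224, width=224, is_training=False, test_crop=True)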

tf2/lars_optimizer.py (+1 -1)
@@ -22,7 +22,7 @@
 EETA_DEFAULT = 0.001
 
 
-class LARSOptimizer(tf.keras.optimizers.Optimizer):
+class LARSOptimizer(tf.keras.optimizers.legacy.Optimizer):
   """Layer-wise Adaptive Rate Scaling for large batch training.
 
   Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
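
The base-class switch likely targets TensorFlow releases (roughly 2.11 and later) where tf.keras.optimizers.Optimizer points to the rewritten Keras optimizer API, while the old-style interface that LARSOptimizer implements lives on under tf.keras.optimizers.legacy. A quick sanity check, assuming such a TF version and that the module is importable as lars_optimizer:

    import tensorflow as tf
    from lars_optimizer import LARSOptimizer  # tf2/lars_optimizer.py

    opt = LARSOptimizer(learning_rate=0.1)
    print(isinstance(opt, tf.keras.optimizers.legacy.Optimizer))  # True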
