Skip to content

Commit 9080613

Browse files
committed
Fix deprecation warnings related to TF v1
1 parent cf9595a commit 9080613

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

keras/backend/tensorflow_backend.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def eager_fn_wrapper(*args, **kwargs):
109109
return eager_fn_wrapper
110110

111111

112+
def _has_compat_v1():
113+
if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
114+
return True
115+
return False
116+
117+
112118
def get_uid(prefix=''):
113119
"""Provides a unique UID given a string prefix.
114120
@@ -2270,7 +2276,11 @@ def _fused_normalize_batch_in_training(x, gamma, beta, reduction_axes,
22702276
if beta.dtype != tf.float32:
22712277
beta = tf.cast(beta, tf.float32)
22722278

2273-
return tf.nn.fused_batch_norm(
2279+
if _has_compat_v1:
2280+
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
2281+
else:
2282+
fused_batch_norm = tf.nn.fused_batch_norm
2283+
return fused_batch_norm(
22742284
x,
22752285
gamma,
22762286
beta,
@@ -2373,7 +2383,12 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
23732383
if var.dtype != tf.float32:
23742384
var = tf.cast(var, tf.float32)
23752385

2376-
y, _, _ = tf.nn.fused_batch_norm(
2386+
if _has_compat_v1:
2387+
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
2388+
else:
2389+
fused_batch_norm = tf.nn.fused_batch_norm
2390+
2391+
y, _, _ = fused_batch_norm(
23772392
x,
23782393
gamma,
23792394
beta,

keras/optimizers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,11 @@ def get(identifier):
853853
if K.backend() == 'tensorflow':
854854
# Wrap TF optimizer instances
855855
if tf.__version__.startswith('1.'):
856-
if isinstance(identifier, tf.train.Optimizer):
856+
try:
857+
TFOpt = tf.compat.v1.train.Optimizer
858+
except AttributeError:
859+
TFOpt = tf.train.Optimizer
860+
if isinstance(identifier, TFOpt):
857861
return TFOptimizer(identifier)
858862
elif isinstance(identifier, tf.keras.optimizers.Optimizer):
859863
return TFOptimizer(identifier)

0 commit comments

Comments
 (0)