diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index ff801b66587..221353d87cd 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -1526,6 +1526,10 @@ def __init__(self): "'merge_repeated' argument and behaves as if merge_repeated=False. " "This call site specifies something other than " "merge_repeated=False, so it was converted to compat.v1."), + "tf.nn.dilation2d": functools.partial( + _add_argument_transformer, + arg_name="data_format", + arg_value_ast=ast.Str("NHWC")), "tf.nn.erosion2d": functools.partial( _add_argument_transformer, arg_name="data_format", diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 58653d8fab2..4464a2aed63 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -2069,6 +2069,12 @@ def testNnErosion2d(self): _, _, _, new_text = self._upgrade(text) self.assertEqual(new_text, expected_text) + def testNnDilation2d(self): + text = "tf.nn.dilation2d(v, k, s, r, p)" + expected_text = "tf.nn.dilation2d(v, k, s, r, p, data_format='NHWC')" + _, _, _, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + def testPywrapTensorflowWarning(self): text = "tf.pywrap_tensorflow.foo()" expected = "tf.pywrap_tensorflow.foo()"