diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py index 81b43145b3..e532db4222 100644 --- a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py @@ -17,11 +17,15 @@ """A plugin to handle remote tensoflow profiler sessions for Vertex AI.""" -from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils +from google.cloud.aiplatform.training_utils.cloud_profiler import ( + cloud_profiler_utils, +) try: import tensorflow as tf - from tensorboard_plugin_profile.profile_plugin import ProfilePlugin + from tensorboard_plugin_profile.profile_plugin import ( + ProfilePlugin, + ) except ImportError as err: raise ImportError(cloud_profiler_utils.import_error_msg) from err @@ -36,10 +40,14 @@ import tensorboard.plugins.base_plugin as tensorboard_base_plugin from werkzeug import Response -from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader +from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import ( + profile_uploader, +) from google.cloud.aiplatform.training_utils import environment_variables from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types -from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import ( + base_plugin, +) from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( tensorboard_api, ) @@ -68,8 +76,7 @@ def _get_tf_versioning() -> Optional[Version]: versioning = version.split(".") if len(versioning) != 3: return - - return Version(int(versioning[0]), int(versioning[1]), int(versioning[2])) + return Version(int(versioning[0]), int(versioning[1]), versioning[2]) def _is_compatible_version(version: Version) -> bool: @@ -228,7 +235,7 @@ def warn_tensorboard_env_var(var_name: str): Required. The name of the missing environment variable. """ logging.warning( - f"Environment variable `{var_name}` must be set. " + _BASE_TB_ENV_WARNING + "Environment variable `%s` must be set. %s", var_name, _BASE_TB_ENV_WARNING ) diff --git a/tests/unit/aiplatform/test_cloud_profiler.py b/tests/unit/aiplatform/test_cloud_profiler.py index 388405d034..b686419361 100644 --- a/tests/unit/aiplatform/test_cloud_profiler.py +++ b/tests/unit/aiplatform/test_cloud_profiler.py @@ -31,8 +31,12 @@ from google.api_core import exceptions from google.cloud import aiplatform from google.cloud.aiplatform import training_utils -from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader -from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin +from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import ( + profile_uploader, +) +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import ( + base_plugin, +) from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( tf_profiler, ) @@ -175,15 +179,21 @@ def tf_import_mock(name, *args, **kwargs): def testCanInitializeTFVersion(self): import tensorflow - with mock.patch.object(tensorflow, "__version__", return_value="1.2.3.4"): + with mock.patch.object(tensorflow, "__version__", "1.2.3.4"): assert not TFProfiler.can_initialize() def testCanInitializeOldTFVersion(self): import tensorflow - with mock.patch.object(tensorflow, "__version__", return_value="2.3.0"): + with mock.patch.object(tensorflow, "__version__", "2.3.0"): assert not TFProfiler.can_initialize() + def testCanInitializeRcTFVersion(self): + import tensorflow as tf + + with mock.patch.object(tf, "__version__", "2.4.0-rc2"): + assert TFProfiler.can_initialize() + def testCanInitializeNoProfilePlugin(self): orig_find_spec = importlib.util.find_spec