From b4e8f13d3cc407fbb8f84efd7ba1fd464f52653a Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Wed, 16 Oct 2024 11:05:46 -0700 Subject: [PATCH] Replace tf.io.gfile with epath.Path. PiperOrigin-RevId: 686570636 --- .../profile_plugin.py | 46 +++++++++---------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/plugin/tensorboard_plugin_profile/profile_plugin.py b/plugin/tensorboard_plugin_profile/profile_plugin.py index cf981445..7a6888ca 100644 --- a/plugin/tensorboard_plugin_profile/profile_plugin.py +++ b/plugin/tensorboard_plugin_profile/profile_plugin.py @@ -27,6 +27,7 @@ import threading from typing import Any, List, TypedDict +from etils import epath import six import tensorflow.compat.v2 as tf from werkzeug import wrappers @@ -193,7 +194,7 @@ def _get_hosts(filenames: list[str]) -> set[str]: return hosts -def _get_tools(filenames: list[str], profile_run_dir: str) -> set[str]: +def _get_tools(filenames: list[Any], profile_run_dir: str) -> set[str]: """Parses a list of filenames and returns the set of tools. If xplane is present in the repository, add tools that can be generated by @@ -210,7 +211,7 @@ def _get_tools(filenames: list[str], profile_run_dir: str) -> set[str]: found = set() xplane_filenames = [] for name in filenames: - _, tool = _parse_filename(name) + _, tool = _parse_filename(name.name) if tool == 'xplane': xplane_filenames.append(os.path.join(profile_run_dir, name)) continue @@ -546,10 +547,11 @@ def _run_host_impl( tool_pattern = make_filename('*', tool) filenames = [] try: - filenames = tf.io.gfile.glob(os.path.join(run_dir, tool_pattern)) - except tf.errors.OpError as e: + path = epath.Path(run_dir) + filenames = path.glob(tool_pattern) + except OSError as e: logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e) - filenames = [os.path.basename(f) for f in filenames] + filenames = [os.fspath(os.path.basename(f)) for f in filenames] return [{'hostname': host} for host in filenames_to_hosts(filenames, tool)] @@ -659,8 +661,9 @@ def data_impl( if host == ALL_HOSTS: file_pattern = make_filename('*', 'xplane') try: - asset_paths = tf.io.gfile.glob(os.path.join(run_dir, file_pattern)) - except tf.errors.OpError as e: + path = epath.Path(run_dir) + asset_paths = path.glob(file_pattern) + except OSError as e: logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e) raise IOError( @@ -684,12 +687,10 @@ def data_impl( raw_data = None try: - with tf.io.gfile.GFile(asset_path, 'rb') as f: - raw_data = f.read() - except tf.errors.NotFoundError: - logger.warning('Asset path %s not found', asset_path) - except tf.errors.OpError as e: - logger.warning("Couldn't read asset path: %s, OpError %s", asset_path, e) + path = epath.Path(asset_path) + raw_data = path.read_bytes() + except OSError as e: + logger.warning("Couldn't read asset path: %s, Error: %s", asset_path, e) if raw_data is None: return None, content_type, None @@ -723,6 +724,7 @@ def capture_route(self, request: wrappers.Request) -> wrappers.Response: def capture_route_impl(self, request: wrappers.Request) -> wrappers.Response: """Runs the client trace for capturing profiling information.""" + service_addr = request.args.get('service_addr') duration = int(request.args.get('duration', '1000')) is_tpu_name = request.args.get('is_tpu_name') == 'true' @@ -784,12 +786,6 @@ def capture_route_impl(self, request: wrappers.Request) -> wrappers.Response: {'result': 'Capture profile successfully. Please refresh.'}, 'application/json', ) - except tf.errors.UnavailableError: - return respond( - {'error': 'empty trace result.'}, - 'application/json', - code=200, - ) except Exception as e: # pylint: disable=broad-except return respond( {'error': str(e)}, @@ -841,7 +837,7 @@ def _run_dir(self, run: str) -> str: if not tb_run_name: tb_run_name = '.' tb_run_directory = _tb_run_directory(self.logdir, tb_run_name) - if not tf.io.gfile.isdir(tb_run_directory): + if not epath.Path(tb_run_directory).is_dir(): raise RuntimeError('No matching run directory for run %s' % run) plugin_directory = plugin_asset_util.PluginDirectory( @@ -904,7 +900,7 @@ def generate_runs(self) -> Iterator[str]: # backwards compatible with previously profile plugin behavior. Note that we # check if logdir is a directory to handle case where it's actually a # multipart directory spec, which this plugin does not support. - if '.' not in tb_runs and tf.io.gfile.isdir(self.logdir): + if '.' not in tb_runs and epath.Path(self.logdir).is_dir(): tb_runs.append('.') tb_run_names_to_dirs = { run: _tb_run_directory(self.logdir, run) for run in tb_runs @@ -923,17 +919,17 @@ def generate_runs(self) -> Iterator[str]: else: frontend_run = os.path.join(tb_run_name, profile_run) profile_run_dir = os.path.join(tb_plugin_dir, profile_run) - if tf.io.gfile.isdir(profile_run_dir): + if epath.Path(profile_run_dir).is_dir(): self._run_to_profile_run_dir[frontend_run] = profile_run_dir yield frontend_run def generate_tools_of_run(self, run: str) -> Iterator[str]: """Generate a list of tools given a certain run.""" profile_run_dir = self._run_to_profile_run_dir[run] - if tf.io.gfile.isdir(profile_run_dir): + if epath.Path(profile_run_dir).is_dir(): try: - filenames = tf.io.gfile.listdir(profile_run_dir) - except tf.errors.NotFoundError as e: + filenames = epath.Path(profile_run_dir).iterdir() + except OSError as e: logger.warning('Cannot read asset directory: %s, NotFoundError %s', profile_run_dir, e) filenames = []