Skip to content

Commit

Permalink
Replace tf.io.gfile with epath.Path.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686570636
  • Loading branch information
cliveverghese authored and copybara-github committed Oct 29, 2024
1 parent 2587a40 commit b4e8f13
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions plugin/tensorboard_plugin_profile/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import threading
from typing import Any, List, TypedDict

from etils import epath
import six
import tensorflow.compat.v2 as tf
from werkzeug import wrappers
Expand Down Expand Up @@ -193,7 +194,7 @@ def _get_hosts(filenames: list[str]) -> set[str]:
return hosts


def _get_tools(filenames: list[str], profile_run_dir: str) -> set[str]:
def _get_tools(filenames: list[Any], profile_run_dir: str) -> set[str]:
"""Parses a list of filenames and returns the set of tools.
If xplane is present in the repository, add tools that can be generated by
Expand All @@ -210,7 +211,7 @@ def _get_tools(filenames: list[str], profile_run_dir: str) -> set[str]:
found = set()
xplane_filenames = []
for name in filenames:
_, tool = _parse_filename(name)
_, tool = _parse_filename(name.name)
if tool == 'xplane':
xplane_filenames.append(os.path.join(profile_run_dir, name))
continue
Expand Down Expand Up @@ -546,10 +547,11 @@ def _run_host_impl(
tool_pattern = make_filename('*', tool)
filenames = []
try:
filenames = tf.io.gfile.glob(os.path.join(run_dir, tool_pattern))
except tf.errors.OpError as e:
path = epath.Path(run_dir)
filenames = path.glob(tool_pattern)
except OSError as e:
logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)
filenames = [os.path.basename(f) for f in filenames]
filenames = [os.fspath(os.path.basename(f)) for f in filenames]

return [{'hostname': host} for host in filenames_to_hosts(filenames, tool)]

Expand Down Expand Up @@ -659,8 +661,9 @@ def data_impl(
if host == ALL_HOSTS:
file_pattern = make_filename('*', 'xplane')
try:
asset_paths = tf.io.gfile.glob(os.path.join(run_dir, file_pattern))
except tf.errors.OpError as e:
path = epath.Path(run_dir)
asset_paths = path.glob(file_pattern)
except OSError as e:
logger.warning('Cannot read asset directory: %s, OpError %s', run_dir,
e)
raise IOError(
Expand All @@ -684,12 +687,10 @@ def data_impl(

raw_data = None
try:
with tf.io.gfile.GFile(asset_path, 'rb') as f:
raw_data = f.read()
except tf.errors.NotFoundError:
logger.warning('Asset path %s not found', asset_path)
except tf.errors.OpError as e:
logger.warning("Couldn't read asset path: %s, OpError %s", asset_path, e)
path = epath.Path(asset_path)
raw_data = path.read_bytes()
except OSError as e:
logger.warning("Couldn't read asset path: %s, Error: %s", asset_path, e)

if raw_data is None:
return None, content_type, None
Expand Down Expand Up @@ -723,6 +724,7 @@ def capture_route(self, request: wrappers.Request) -> wrappers.Response:

def capture_route_impl(self, request: wrappers.Request) -> wrappers.Response:
"""Runs the client trace for capturing profiling information."""

service_addr = request.args.get('service_addr')
duration = int(request.args.get('duration', '1000'))
is_tpu_name = request.args.get('is_tpu_name') == 'true'
Expand Down Expand Up @@ -784,12 +786,6 @@ def capture_route_impl(self, request: wrappers.Request) -> wrappers.Response:
{'result': 'Capture profile successfully. Please refresh.'},
'application/json',
)
except tf.errors.UnavailableError:
return respond(
{'error': 'empty trace result.'},
'application/json',
code=200,
)
except Exception as e: # pylint: disable=broad-except
return respond(
{'error': str(e)},
Expand Down Expand Up @@ -841,7 +837,7 @@ def _run_dir(self, run: str) -> str:
if not tb_run_name:
tb_run_name = '.'
tb_run_directory = _tb_run_directory(self.logdir, tb_run_name)
if not tf.io.gfile.isdir(tb_run_directory):
if not epath.Path(tb_run_directory).is_dir():
raise RuntimeError('No matching run directory for run %s' % run)

plugin_directory = plugin_asset_util.PluginDirectory(
Expand Down Expand Up @@ -904,7 +900,7 @@ def generate_runs(self) -> Iterator[str]:
# backwards compatible with previously profile plugin behavior. Note that we
# check if logdir is a directory to handle case where it's actually a
# multipart directory spec, which this plugin does not support.
if '.' not in tb_runs and tf.io.gfile.isdir(self.logdir):
if '.' not in tb_runs and epath.Path(self.logdir).is_dir():
tb_runs.append('.')
tb_run_names_to_dirs = {
run: _tb_run_directory(self.logdir, run) for run in tb_runs
Expand All @@ -923,17 +919,17 @@ def generate_runs(self) -> Iterator[str]:
else:
frontend_run = os.path.join(tb_run_name, profile_run)
profile_run_dir = os.path.join(tb_plugin_dir, profile_run)
if tf.io.gfile.isdir(profile_run_dir):
if epath.Path(profile_run_dir).is_dir():
self._run_to_profile_run_dir[frontend_run] = profile_run_dir
yield frontend_run

def generate_tools_of_run(self, run: str) -> Iterator[str]:
"""Generate a list of tools given a certain run."""
profile_run_dir = self._run_to_profile_run_dir[run]
if tf.io.gfile.isdir(profile_run_dir):
if epath.Path(profile_run_dir).is_dir():
try:
filenames = tf.io.gfile.listdir(profile_run_dir)
except tf.errors.NotFoundError as e:
filenames = epath.Path(profile_run_dir).iterdir()
except OSError as e:
logger.warning('Cannot read asset directory: %s, NotFoundError %s',
profile_run_dir, e)
filenames = []
Expand Down

0 comments on commit b4e8f13

Please sign in to comment.