Commit 9e64388

Makes axlearn/cloud/ use file_system. (apple#998)
* Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency.
* Adds testing for file_system.readfile.
* Fixes pytype.
* Makes axlearn/cloud use file_system instead of gfile.
1 parent 5fba4ce commit 9e64388
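
The pattern applied throughout is mechanical: drop the `tensorflow.io` / `tensorflow.errors` imports and call the equivalents exposed by `axlearn.common.file_system`. A minimal before/after sketch of a call site (the `gs://` path is a placeholder; in `bastion.py`, `file_system.open` is imported as `fs_open` to avoid shadowing the builtin):

```python
# Before: gfile-based I/O with TF-specific error types.
# from tensorflow import errors as tf_errors
# from tensorflow import io as tf_io
# with tf_io.gfile.GFile("gs://bucket/state.json", mode="r") as f:
#     contents = f.read()

# After: the file_system wrapper provides matching semantics.
from axlearn.common.file_system import NotFoundError, readfile

try:
    contents = readfile("gs://bucket/state.json")  # Placeholder path.
except NotFoundError:  # Replaces tf_errors.NotFoundError.
    contents = None
```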

File tree

10 files changed (+93 additions, -82 deletions)

axlearn/cloud/common/bastion.py

Lines changed: 37 additions & 38 deletions

@@ -66,8 +66,6 @@
 from typing import IO, Any, Optional, Union

 from absl import flags, logging
-from tensorflow import errors as tf_errors
-from tensorflow import io as tf_io
 from tensorflow import nest as tf_nest

 # tensorflow_io import is necessary for tf_io to understand s3:// scheme.
@@ -92,6 +90,9 @@
     config_for_function,
     maybe_instantiate,
 )
+from axlearn.common.file_system import NotFoundError, copy, exists, isdir, listdir, makedirs
+from axlearn.common.file_system import open as fs_open
+from axlearn.common.file_system import readfile, remove, rmtree
 from axlearn.common.utils import Nested

 _LATEST_BASTION_VERSION = 1  # Determines job schema (see JobSpec).
@@ -113,34 +114,33 @@ def bastion_job_flags(flag_values: flags.FlagValues = FLAGS):


 # The following functions, `_download`, `_readfile`, `_listdir`, and `_remove`, can be patched to
-# support alternative storages that cannot be accessed via gfile.
+# support alternative storages that cannot be accessed via file_system.
 #
 # TODO(ruoming): refactor them to a `BastionDirStorage` class.
 def _download(path: str, local_file: str):
-    tf_io.gfile.copy(path, local_file, overwrite=True)
+    copy(path, local_file, overwrite=True)


 def _readfile(path: str) -> str:
-    with tf_io.gfile.GFile(path, mode="r") as f:
-        return f.read()
+    return readfile(path)


 def _listdir(path: str) -> list[str]:
-    """Wraps tf_io.gfile.listdir by returning empty list if dir is not found."""
+    """Wraps file_system.listdir by returning empty list if dir is not found."""
     try:
-        return tf_io.gfile.listdir(path)
-    except tf_errors.NotFoundError:
+        return listdir(path)
+    except NotFoundError:
         return []


 def _remove(path: str):
-    """Wraps tf_io.gfile.remove by catching not found errors."""
+    """Wraps file_system.remove by catching not found errors."""
     try:
-        if tf_io.gfile.isdir(path):
-            tf_io.gfile.rmtree(path)
+        if isdir(path):
+            rmtree(path)
         else:
-            tf_io.gfile.remove(path)
-    except tf_errors.NotFoundError:
+            remove(path)
+    except NotFoundError:
         pass


@@ -357,7 +357,7 @@ def _upload_jobspec(spec: JobSpec, *, remote_dir: str, local_dir: str = _JOB_DIR
     local_file = os.path.join(local_dir, spec.name)
     remote_file = os.path.join(remote_dir, spec.name)
     serialize_jobspec(spec, local_file)
-    tf_io.gfile.copy(local_file, remote_file, overwrite=True)
+    copy(local_file, remote_file, overwrite=True)


 @dataclasses.dataclass
@@ -440,7 +440,7 @@ def _download_job_state(job_name: str, *, remote_dir: str) -> JobState:
     """Loads job state from gs path."""
     remote_file = os.path.join(remote_dir, job_name)
     try:
-        # Note: tf_io.gfile.GFile seems to hit libcurl errors with ThreadPoolExecutor.
+        # Note: GFile seems to hit libcurl errors with ThreadPoolExecutor.
         contents = _readfile(remote_file)
         try:
             state = json.loads(contents)
@@ -449,7 +449,7 @@ def _download_job_state(job_name: str, *, remote_dir: str) -> JobState:
             state = dict(status=contents)
         state["status"] = JobStatus[state["status"].strip().upper()]
         return JobState(**state)
-    except tf_errors.NotFoundError:
+    except NotFoundError:
         # No job state, defaults to PENDING.
         return JobState(status=JobStatus.PENDING)

@@ -458,7 +458,7 @@ def _upload_job_state(job_name: str, state: JobState, *, remote_dir: str, verbos
     """Uploads job state to gs path."""
     remote_file = os.path.join(remote_dir, job_name)
     logging.log_if(logging.INFO, "Writing %s to %s.", verbose, state.status.name, remote_file)
-    with tf_io.gfile.GFile(remote_file, mode="w") as f:
+    with fs_open(remote_file, mode="w") as f:
         json.dump(dataclasses.asdict(state), f)


@@ -471,7 +471,7 @@ def _start_command(job: Job, *, remote_log_dir: str, env_vars: dict):
     local_log = os.path.join(_LOG_DIR, job.spec.name)
     try:
         _download(remote_log, local_log)
-    except tf_errors.NotFoundError:
+    except NotFoundError:
         pass
     # Pipe all outputs to the local _LOG_DIR.
     job.command_proc = _piped_popen(
@@ -648,9 +648,8 @@ def _load_runtime_options(bastion_dir: str) -> dict[str, Any]:
     """Loads runtime option(s) from file, or returns {} on failure."""
     flag_file = os.path.join(bastion_dir, "runtime_options")
     try:
-        with tf_io.gfile.GFile(flag_file, "r") as f:
-            return json.load(f)
-    except (tf_errors.NotFoundError, json.JSONDecodeError) as e:
+        return json.loads(readfile(flag_file))
+    except (NotFoundError, json.JSONDecodeError) as e:
         logging.warning("Failed to load runtime options: %s", e)
         return {}

@@ -660,7 +659,7 @@ def set_runtime_options(bastion_dir: str, **kwargs) -> Nested[Any]:
     runtime_options = _load_runtime_options(bastion_dir)
     runtime_options = merge(runtime_options, kwargs)
     flag_file = os.path.join(bastion_dir, "runtime_options")
-    with tf_io.gfile.GFile(flag_file, "w") as f:
+    with fs_open(flag_file, "w") as f:
         json.dump(runtime_options, f)
     logging.info("Updated runtime options: %s", runtime_options)
     return runtime_options
@@ -699,11 +698,11 @@ def __init__(self, cfg: Config):
         self._job_dir = os.path.join(self._output_dir, "jobs")
         # Remote history dir. Ensure trailing slash.
         self._job_history_dir = os.path.join(self._output_dir, "history", "jobs")
-        tf_io.gfile.makedirs(self._job_history_dir)
+        makedirs(self._job_history_dir)
         self._project_history_dir = os.path.join(self._output_dir, "history", "projects")
-        tf_io.gfile.makedirs(self._project_history_dir)
+        makedirs(self._project_history_dir)
         self._scheduler_history_dir = os.path.join(self._output_dir, "history", "scheduler")
-        tf_io.gfile.makedirs(self._scheduler_history_dir)
+        makedirs(self._scheduler_history_dir)
         # Mapping from project_id to previous job verdicts.
         self._project_history_previous_verdicts = {}
         # Jobs that have fully completed.
@@ -733,7 +732,7 @@ def __init__(self, cfg: Config):
         self._event_publisher = maybe_instantiate(cfg.event_publisher)

     def _append_to_job_history(self, job: Job, *, msg: str, state: JobLifecycleState):
-        with tf_io.gfile.GFile(os.path.join(self._job_history_dir, f"{job.spec.name}"), "a") as f:
+        with fs_open(os.path.join(self._job_history_dir, f"{job.spec.name}"), "a") as f:
             curr_time = datetime.now(timezone.utc).strftime("%m%d %H:%M:%S")
             f.write(f"{curr_time} {msg}\n")
         # Publish event into queue.
@@ -751,7 +750,7 @@ def _append_to_history(
         self, jobs: dict[str, JobMetadata], schedule_results: BaseScheduler.ScheduleResults
     ):
         now = datetime.now(timezone.utc)
-        with tf_io.gfile.GFile(
+        with fs_open(
             os.path.join(self._scheduler_history_dir, now.strftime("%Y%m%d-%H%M%S")), "a"
         ) as f:
             for job_id, verdict in schedule_results.job_verdicts.items():
@@ -794,8 +793,8 @@ def resource_str(resource_map: ResourceMap) -> str:
             )

             project_dir = os.path.join(self._project_history_dir, project_id)
-            tf_io.gfile.makedirs(project_dir)
-            with tf_io.gfile.GFile(os.path.join(project_dir, now.strftime("%Y%m%d")), "a") as f:
+            makedirs(project_dir)
+            with fs_open(os.path.join(project_dir, now.strftime("%Y%m%d")), "a") as f:
                 curr_time = now.strftime("%m%d %H:%M:%S")
                 f.write(f"{curr_time}\n")
                 f.write(f"Effective limits: {resource_str(limits)}\n")
@@ -838,7 +837,7 @@ def _wait_and_close_proc(self, proc: _PipedProcess, kill: bool = False):
         proc.fd.close()
         # Upload outputs to log dir.
         _catch_with_error_log(
-            tf_io.gfile.copy,
+            copy,
             proc.fd.name,
             os.path.join(self._log_dir, os.path.basename(proc.fd.name)),
             overwrite=True,
@@ -1301,7 +1300,7 @@ def list_jobs(self):
     def cancel_job(self, job_name: str):
         try:
             jobspec = os.path.join(self.active_job_dir, job_name)
-            if not tf_io.gfile.exists(jobspec):
+            if not exists(jobspec):
                 raise ValueError(f"Unable to locate jobspec {jobspec}")
             _upload_job_state(
                 job_name,
@@ -1310,7 +1309,7 @@ def cancel_job(self, job_name: str):
             )
             logging.info("Job %s is cancelling.", job_name)
             # Poll for jobspec to be removed.
-            while tf_io.gfile.exists(jobspec):
+            while exists(jobspec):
                 logging.info("Waiting for job to stop (which usually takes a few minutes)...")
                 time.sleep(10)
             logging.info("Job is stopped.")
@@ -1321,15 +1320,15 @@ def submit_job(self, job_name: str, *, job_spec_file: str):
         if not is_valid_job_name(job_name):
             raise ValueError(f"{job_name} is not a valid job name.")
         dst = os.path.join(self.active_job_dir, job_name)
-        if tf_io.gfile.exists(dst):
+        if exists(dst):
             logging.info("\n\nNote: Job is already running. To restart it, cancel the job first.\n")
         else:
             # Upload the job for bastion to pickup.
-            tf_io.gfile.copy(job_spec_file, dst)
+            copy(job_spec_file, dst)

     def get_job(self, job_name: str) -> JobSpec:
         job_path = os.path.join(self.active_job_dir, job_name)
-        if not tf_io.gfile.exists(job_path):
+        if not exists(job_path):
             raise ValueError(f"Unable to locate jobspec {job_path}")

         with tempfile.TemporaryDirectory() as tmpdir:
@@ -1338,13 +1337,13 @@ def get_job(self, job_name: str) -> JobSpec:

     def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec:
         dst = os.path.join(self.active_job_dir, job_name)
-        if not tf_io.gfile.exists(dst):
+        if not exists(dst):
             raise ValueError(f"Unable to locate jobspec {dst}")

         with tempfile.NamedTemporaryFile("w") as f:
             serialize_jobspec(job_spec, f)
             # Upload the job for bastion to pickup.
-            tf_io.gfile.copy(f.name, dst, overwrite=True)
+            copy(f.name, dst, overwrite=True)
         logging.info("Job %s is updating.", job_name)

         return job_spec
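
Note that `_download`, `_readfile`, `_listdir`, and `_remove` remain module-level helpers precisely so they can be patched to reach storages that `file_system` does not support (per the `TODO(ruoming)` above). A hypothetical sketch of swapping in an alternative backend, using an in-memory store as a stand-in:

```python
from unittest import mock

from axlearn.cloud.common import bastion

# Hypothetical in-memory backend, for illustration only.
_FAKE_STORE = {"bastion/jobs/demo": '{"status": "PENDING"}'}


def _fake_readfile(path: str) -> str:
    # A real backend would translate misses into file_system.NotFoundError.
    return _FAKE_STORE[path]


def _fake_listdir(path: str) -> list[str]:
    prefix = path.rstrip("/") + "/"
    return [k[len(prefix):] for k in _FAKE_STORE if k.startswith(prefix)]


# Patching the module-level helpers redirects bastion's I/O to the fake store.
with mock.patch.multiple(
    bastion.__name__, _readfile=_fake_readfile, _listdir=_fake_listdir
):
    assert bastion._listdir("bastion/jobs") == ["demo"]
```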

axlearn/cloud/common/bastion_test.py

Lines changed: 8 additions & 8 deletions

@@ -889,7 +889,7 @@ def test_pending(self, popen_spec, user_state_exists):
             send_signal=mock.DEFAULT,
         )
         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.Mock(return_value=user_state_exists),
             copy=mock.DEFAULT,
             remove=mock.DEFAULT,
@@ -1003,7 +1003,7 @@ def mock_tfio_exists(f):
             _upload_job_state=mock.DEFAULT,
         )
         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.MagicMock(side_effect=mock_tfio_exists),
             copy=mock.DEFAULT,
         )
@@ -1224,7 +1224,7 @@ def mock_proc(cmd, **kwargs):
             send_signal=mock.DEFAULT,
         )
         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.DEFAULT,
             copy=mock.DEFAULT,
             remove=mock.DEFAULT,
@@ -1378,7 +1378,7 @@ def test_gc_jobs(self):
         rescheduled = ["rescheduled"]

         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             remove=mock.DEFAULT,
         )
         with self._patch_bastion() as mock_bastion, patch_tfio as mock_tfio:
@@ -1584,7 +1584,7 @@ def test_submit_job(self, job_name, spec_exists):
         self.assertEqual("test-dir/jobs/states", bastion_dir.job_states_dir)
         self.assertEqual("test-dir/jobs/user_states", bastion_dir.user_states_dir)
         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.MagicMock(return_value=spec_exists),
             copy=mock.DEFAULT,
         )
@@ -1609,7 +1609,7 @@ def test_delete(self, spec_exists):
             bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate()
         )
         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.MagicMock(side_effect=[spec_exists, False]),
             copy=mock.DEFAULT,
         )
@@ -1636,7 +1636,7 @@ def test_get(self, spec_exists):
         )

         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.MagicMock(return_value=spec_exists),
             copy=mock.DEFAULT,
         )
@@ -1670,7 +1670,7 @@ def test_update(self, spec_exists):
         )

         patch_tfio = mock.patch.multiple(
-            f"{bastion.__name__}.tf_io.gfile",
+            f"{bastion.__name__}",
             exists=mock.MagicMock(return_value=spec_exists),
             copy=mock.DEFAULT,
         )
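
The test changes follow from the import style above: because `bastion.py` now binds `exists`, `copy`, `remove`, etc. directly into its own namespace via `from ... import`, mocks must target the `bastion` module itself rather than the `tf_io.gfile` attribute — patch where a name is looked up, not where it is defined. A minimal sketch of the idiom (the path is a placeholder):

```python
from unittest import mock

from axlearn.cloud.common import bastion

# `exists` is resolved as `bastion.exists` at call time, so the patch
# must replace the name in bastion's namespace, not in file_system.
with mock.patch.multiple(f"{bastion.__name__}", exists=mock.Mock(return_value=True)):
    assert bastion.exists("gs://any/path")
```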

axlearn/cloud/common/bundler.py

Lines changed: 4 additions & 4 deletions

@@ -51,7 +51,6 @@

 import prefixed
 from absl import app, flags, logging
-from tensorflow import io as tf_io

 from axlearn.cloud.common import config
 from axlearn.cloud.common.docker import build as docker_build
@@ -68,6 +67,7 @@
     running_from_source,
 )
 from axlearn.common.config import REQUIRED, Configurable, Required, config_class
+from axlearn.common.file_system import copy, exists, makedirs

 BUNDLE_EXCLUDE = [
     # Each entry below specifies a subdir/file name or a relative path from the src dir whose
@@ -575,12 +575,12 @@ def bundle(self, name: str) -> str:
         return remote_path

     def _copy_to_remote(self, *, local_path: str, remote_path: str):
-        if tf_io.gfile.exists(remote_path):
+        if exists(remote_path):
             logging.info("Overwriting existing bundle at %s", remote_path)
         else:
             logging.info("Uploading bundle to: %s", remote_path)
-            tf_io.gfile.makedirs(os.path.dirname(remote_path))
-        tf_io.gfile.copy(local_path, remote_path, overwrite=True)
+            makedirs(os.path.dirname(remote_path))
+        copy(local_path, remote_path, overwrite=True)

     def _copy_to_local_command(self, *, remote_bundle_id: str, local_bundle_id: str) -> str:
         """Emits a command to copy a bundle from remote to local.

axlearn/cloud/common/quota.py

Lines changed: 12 additions & 13 deletions

@@ -8,9 +8,9 @@
 from typing import Protocol

 import toml
-from tensorflow import io as tf_io

 from axlearn.cloud.common.types import ProjectResourceMap, ResourceMap
+from axlearn.common.file_system import readfile

 QUOTA_CONFIG_PATH = "project-quotas/project-quotas.config"

@@ -73,18 +73,17 @@ def get_resource_limits(path: str) -> UserQuotaInfo:
     Raises:
         ValueError: If unable to parse quota config file.
     """
-    with tf_io.gfile.GFile(path, mode="r") as f:
-        cfg = toml.loads(f.read())
-        if cfg["toml-schema"]["version"] == "1":
-            total_resources = cfg["total_resources"]
-            if not isinstance(total_resources, Sequence):
-                total_resources = [total_resources]
-            return UserQuotaInfo(
-                total_resources=total_resources,
-                project_resources=cfg["project_resources"],
-                project_membership=cfg["project_membership"],
-            )
-        raise ValueError(f"Unsupported schema version {cfg['toml-schema']['version']}")
+    cfg = toml.loads(readfile(path))
+    if cfg["toml-schema"]["version"] == "1":
+        total_resources = cfg["total_resources"]
+        if not isinstance(total_resources, Sequence):
+            total_resources = [total_resources]
+        return UserQuotaInfo(
+            total_resources=total_resources,
+            project_resources=cfg["project_resources"],
+            project_membership=cfg["project_membership"],
+        )
+    raise ValueError(f"Unsupported schema version {cfg['toml-schema']['version']}")


 def get_user_projects(path: str, user_id: str) -> list[str]:
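
For reference, a quota config that would satisfy the v1 schema read by `get_resource_limits` might look like the following (all names and numbers are illustrative, not from this commit):

```python
import toml

_SAMPLE_CONFIG = """
[toml-schema]
version = "1"

[[total_resources]]
v4 = 1024

[project_resources.team-a]
v4 = 512

[project_membership]
team-a = ["user-a", "user-b"]
"""

cfg = toml.loads(_SAMPLE_CONFIG)
assert cfg["toml-schema"]["version"] == "1"
```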

axlearn/cloud/common/utils.py

Lines changed: 1 addition & 1 deletion

@@ -301,7 +301,7 @@ def copy_blobs(from_prefix: str, *, to_prefix: str):
     # pylint: disable-next=import-outside-toplevel
     from axlearn.common import file_system as fs

-    # As tf_io.gfile.copy requires a path to a file when reading from cloud storage,
+    # As file_system.copy requires a path to a file when reading from cloud storage,
     # we traverse the `from_prefix` to find and copy all suffixes.
     if not fs.isdir(from_prefix):
         # Copy the file.
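
The comment above alludes to the directory traversal in `copy_blobs`; a simplified sketch of that pattern, assuming `fs.isdir`, `fs.listdir`, `fs.makedirs`, and `fs.copy` behave like their gfile counterparts (this is not the actual helper):

```python
import os

from axlearn.common import file_system as fs


def copy_tree(from_prefix: str, to_prefix: str):
    """Recursively copies a file or directory tree (illustrative sketch)."""
    if not fs.isdir(from_prefix):
        fs.copy(from_prefix, to_prefix, overwrite=True)  # Single file: copy directly.
        return
    fs.makedirs(to_prefix)
    for name in fs.listdir(from_prefix):
        name = name.rstrip("/")  # listdir may append "/" to subdirectories.
        copy_tree(os.path.join(from_prefix, name), os.path.join(to_prefix, name))
```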

0 commit comments

Comments
 (0)