@@ -66,8 +66,6 @@
 from typing import IO, Any, Optional, Union

 from absl import flags, logging
-from tensorflow import errors as tf_errors
-from tensorflow import io as tf_io
 from tensorflow import nest as tf_nest

 # tensorflow_io import is necessary for tf_io to understand s3:// scheme.
@@ -92,6 +90,9 @@
     config_for_function,
     maybe_instantiate,
 )
+from axlearn.common.file_system import NotFoundError, copy, exists, isdir, listdir, makedirs
+from axlearn.common.file_system import open as fs_open
+from axlearn.common.file_system import readfile, remove, rmtree
 from axlearn.common.utils import Nested

 _LATEST_BASTION_VERSION = 1  # Determines job schema (see JobSpec).
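Note the alias on the second new import line: bringing in `file_system.open` as `fs_open` avoids shadowing Python's builtin `open` in this module. A minimal sketch of the distinction (assuming, as the imports suggest, that `file_system.open` dispatches on the path scheme and also handles plain local paths):

```python
from axlearn.common.file_system import open as fs_open

# The alias leaves the builtin `open` untouched in this namespace.
with open("/tmp/local.txt", "w") as f:  # builtin: local filesystem only
    f.write("local")

# fs_open is assumed to accept scheme-prefixed paths (gs://, s3://) as well
# as plain local paths, routing each to the appropriate backend.
with fs_open("/tmp/also-local.txt", mode="w") as f:
    f.write("also local")
```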
@@ -113,34 +114,33 @@ def bastion_job_flags(flag_values: flags.FlagValues = FLAGS):

 # The following functions, `_download`, `_readfile`, `_listdir`, and `_remove`, can be patched to
-# support alternative storages that cannot be accessed via gfile.
+# support alternative storages that cannot be accessed via file_system.
 #
 # TODO(ruoming): refactor them to a `BastionDirStorage` class.
 def _download(path: str, local_file: str):
-    tf_io.gfile.copy(path, local_file, overwrite=True)
+    copy(path, local_file, overwrite=True)


 def _readfile(path: str) -> str:
-    with tf_io.gfile.GFile(path, mode="r") as f:
-        return f.read()
+    return readfile(path)


 def _listdir(path: str) -> list[str]:
-    """Wraps tf_io.gfile.listdir by returning empty list if dir is not found."""
+    """Wraps file_system.listdir by returning empty list if dir is not found."""
     try:
-        return tf_io.gfile.listdir(path)
-    except tf_errors.NotFoundError:
+        return listdir(path)
+    except NotFoundError:
         return []


 def _remove(path: str):
-    """Wraps tf_io.gfile.remove by catching not found errors."""
+    """Wraps file_system.remove by catching not found errors."""
     try:
-        if tf_io.gfile.isdir(path):
-            tf_io.gfile.rmtree(path)
+        if isdir(path):
+            rmtree(path)
         else:
-            tf_io.gfile.remove(path)
-    except tf_errors.NotFoundError:
+            remove(path)
+    except NotFoundError:
         pass

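As the comment above notes, these module-level helpers are the seam for alternative storage backends until they are folded into a `BastionDirStorage` class. A hedged sketch of how a deployment might swap one out, assuming this module is importable as `axlearn.cloud.common.bastion` and using a hypothetical custom fetcher; since callers resolve the helpers through the module namespace, patching the module attribute is sufficient:

```python
from unittest import mock

from axlearn.cloud.common import bastion


def _my_download(path: str, local_file: str):
    """Hypothetical: fetch `path` from a custom store into `local_file`."""
    ...


# Patch for the duration of a bastion run (a plain module-attribute
# assignment works too, at the cost of not being reversible).
with mock.patch.object(bastion, "_download", _my_download):
    ...  # run the bastion loop
```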
@@ -357,7 +357,7 @@ def _upload_jobspec(spec: JobSpec, *, remote_dir: str, local_dir: str = _JOB_DIR
     local_file = os.path.join(local_dir, spec.name)
     remote_file = os.path.join(remote_dir, spec.name)
     serialize_jobspec(spec, local_file)
-    tf_io.gfile.copy(local_file, remote_file, overwrite=True)
+    copy(local_file, remote_file, overwrite=True)


 @dataclasses.dataclass
@dataclasses .dataclass
@@ -440,7 +440,7 @@ def _download_job_state(job_name: str, *, remote_dir: str) -> JobState:
     """Loads job state from gs path."""
     remote_file = os.path.join(remote_dir, job_name)
     try:
-        # Note: tf_io.gfile.GFile seems to hit libcurl errors with ThreadPoolExecutor.
+        # Note: GFile seems to hit libcurl errors with ThreadPoolExecutor.
         contents = _readfile(remote_file)
         try:
             state = json.loads(contents)
@@ -449,7 +449,7 @@ def _download_job_state(job_name: str, *, remote_dir: str) -> JobState:
             state = dict(status=contents)
         state["status"] = JobStatus[state["status"].strip().upper()]
         return JobState(**state)
-    except tf_errors.NotFoundError:
+    except NotFoundError:
         # No job state, defaults to PENDING.
         return JobState(status=JobStatus.PENDING)

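The read path above tolerates two on-disk formats: a JSON object with at least a `status` key (plus any other `JobState` fields), and, for backward compatibility, a bare status string. A self-contained sketch of the fallback, with illustrative file contents:

```python
import json


def parse_status(contents: str) -> str:
    """Mirrors the fallback in `_download_job_state`: JSON first, raw string second."""
    try:
        state = json.loads(contents)
    except json.JSONDecodeError:
        state = dict(status=contents)
    # The real code then looks this value up via JobStatus[...].
    return state["status"].strip().upper()


assert parse_status('{"status": "PENDING"}') == "PENDING"
assert parse_status("pending\n") == "PENDING"
```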
@@ -458,7 +458,7 @@ def _upload_job_state(job_name: str, state: JobState, *, remote_dir: str, verbos
     """Uploads job state to gs path."""
     remote_file = os.path.join(remote_dir, job_name)
     logging.log_if(logging.INFO, "Writing %s to %s.", verbose, state.status.name, remote_file)
-    with tf_io.gfile.GFile(remote_file, mode="w") as f:
+    with fs_open(remote_file, mode="w") as f:
         json.dump(dataclasses.asdict(state), f)


@@ -471,7 +471,7 @@ def _start_command(job: Job, *, remote_log_dir: str, env_vars: dict):
     local_log = os.path.join(_LOG_DIR, job.spec.name)
     try:
         _download(remote_log, local_log)
-    except tf_errors.NotFoundError:
+    except NotFoundError:
         pass
     # Pipe all outputs to the local _LOG_DIR.
     job.command_proc = _piped_popen(
@@ -648,9 +648,8 @@ def _load_runtime_options(bastion_dir: str) -> dict[str, Any]:
     """Loads runtime option(s) from file, or returns {} on failure."""
     flag_file = os.path.join(bastion_dir, "runtime_options")
     try:
-        with tf_io.gfile.GFile(flag_file, "r") as f:
-            return json.load(f)
-    except (tf_errors.NotFoundError, json.JSONDecodeError) as e:
+        return json.loads(readfile(flag_file))
+    except (NotFoundError, json.JSONDecodeError) as e:
         logging.warning("Failed to load runtime options: %s", e)
         return {}

@@ -660,7 +659,7 @@ def set_runtime_options(bastion_dir: str, **kwargs) -> Nested[Any]:
     runtime_options = _load_runtime_options(bastion_dir)
     runtime_options = merge(runtime_options, kwargs)
     flag_file = os.path.join(bastion_dir, "runtime_options")
-    with tf_io.gfile.GFile(flag_file, "w") as f:
+    with fs_open(flag_file, "w") as f:
         json.dump(runtime_options, f)
     logging.info("Updated runtime options: %s", runtime_options)
     return runtime_options
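`set_runtime_options` is thus a read-merge-write cycle over the single `<bastion_dir>/runtime_options` JSON file. A usage sketch; the bucket path and the option keys here are hypothetical:

```python
bastion_dir = "gs://my-bucket/my-bastion"  # hypothetical output dir

# Merges into whatever is already stored under <bastion_dir>/runtime_options
# and returns the merged result.
options = set_runtime_options(bastion_dir, scheduler=dict(dry_run=True))
```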
@@ -699,11 +698,11 @@ def __init__(self, cfg: Config):
         self._job_dir = os.path.join(self._output_dir, "jobs")
         # Remote history dir. Ensure trailing slash.
         self._job_history_dir = os.path.join(self._output_dir, "history", "jobs")
-        tf_io.gfile.makedirs(self._job_history_dir)
+        makedirs(self._job_history_dir)
         self._project_history_dir = os.path.join(self._output_dir, "history", "projects")
-        tf_io.gfile.makedirs(self._project_history_dir)
+        makedirs(self._project_history_dir)
         self._scheduler_history_dir = os.path.join(self._output_dir, "history", "scheduler")
-        tf_io.gfile.makedirs(self._scheduler_history_dir)
+        makedirs(self._scheduler_history_dir)
         # Mapping from project_id to previous job verdicts.
         self._project_history_previous_verdicts = {}
         # Jobs that have fully completed.
@@ -733,7 +732,7 @@ def __init__(self, cfg: Config):
         self._event_publisher = maybe_instantiate(cfg.event_publisher)

     def _append_to_job_history(self, job: Job, *, msg: str, state: JobLifecycleState):
-        with tf_io.gfile.GFile(os.path.join(self._job_history_dir, f"{job.spec.name}"), "a") as f:
+        with fs_open(os.path.join(self._job_history_dir, f"{job.spec.name}"), "a") as f:
             curr_time = datetime.now(timezone.utc).strftime("%m%d %H:%M:%S")
             f.write(f"{curr_time} {msg}\n")
         # Publish event into queue.
@@ -751,7 +750,7 @@ def _append_to_history(
         self, jobs: dict[str, JobMetadata], schedule_results: BaseScheduler.ScheduleResults
     ):
         now = datetime.now(timezone.utc)
-        with tf_io.gfile.GFile(
+        with fs_open(
             os.path.join(self._scheduler_history_dir, now.strftime("%Y%m%d-%H%M%S")), "a"
         ) as f:
             for job_id, verdict in schedule_results.job_verdicts.items():
@@ -794,8 +793,8 @@ def resource_str(resource_map: ResourceMap) -> str:
             )

         project_dir = os.path.join(self._project_history_dir, project_id)
-        tf_io.gfile.makedirs(project_dir)
-        with tf_io.gfile.GFile(os.path.join(project_dir, now.strftime("%Y%m%d")), "a") as f:
+        makedirs(project_dir)
+        with fs_open(os.path.join(project_dir, now.strftime("%Y%m%d")), "a") as f:
             curr_time = now.strftime("%m%d %H:%M:%S")
             f.write(f"{curr_time}\n")
             f.write(f"Effective limits: {resource_str(limits)}\n")
@@ -838,7 +837,7 @@ def _wait_and_close_proc(self, proc: _PipedProcess, kill: bool = False):
         proc.fd.close()
         # Upload outputs to log dir.
         _catch_with_error_log(
-            tf_io.gfile.copy,
+            copy,
             proc.fd.name,
             os.path.join(self._log_dir, os.path.basename(proc.fd.name)),
             overwrite=True,
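`_catch_with_error_log` (defined elsewhere in this file) takes the callable plus its arguments, so an upload failure is logged rather than raised while the process is being closed. A plausible shape for such a wrapper, as an illustration only, not the actual implementation:

```python
def _catch_with_error_log(fn, *args, **kwargs):
    """Illustrative only: invoke fn, log instead of propagating failures."""
    try:
        return fn(*args, **kwargs)
    except Exception as e:  # pylint: disable=broad-except
        logging.error("%s(*%s, **%s) failed: %s", fn.__name__, args, kwargs, e)
        return None
```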
@@ -1301,7 +1300,7 @@ def list_jobs(self):
     def cancel_job(self, job_name: str):
         try:
             jobspec = os.path.join(self.active_job_dir, job_name)
-            if not tf_io.gfile.exists(jobspec):
+            if not exists(jobspec):
                 raise ValueError(f"Unable to locate jobspec {jobspec}")
             _upload_job_state(
                 job_name,
@@ -1310,7 +1309,7 @@ def cancel_job(self, job_name: str):
             )
             logging.info("Job %s is cancelling.", job_name)
             # Poll for jobspec to be removed.
-            while tf_io.gfile.exists(jobspec):
+            while exists(jobspec):
                 logging.info("Waiting for job to stop (which usually takes a few minutes)...")
                 time.sleep(10)
             logging.info("Job is stopped.")
@@ -1321,15 +1320,15 @@ def submit_job(self, job_name: str, *, job_spec_file: str):
         if not is_valid_job_name(job_name):
             raise ValueError(f"{job_name} is not a valid job name.")
         dst = os.path.join(self.active_job_dir, job_name)
-        if tf_io.gfile.exists(dst):
+        if exists(dst):
             logging.info("\n\nNote: Job is already running. To restart it, cancel the job first.\n")
         else:
             # Upload the job for bastion to pickup.
-            tf_io.gfile.copy(job_spec_file, dst)
+            copy(job_spec_file, dst)

     def get_job(self, job_name: str) -> JobSpec:
         job_path = os.path.join(self.active_job_dir, job_name)
-        if not tf_io.gfile.exists(job_path):
+        if not exists(job_path):
             raise ValueError(f"Unable to locate jobspec {job_path}")

         with tempfile.TemporaryDirectory() as tmpdir:
@@ -1338,13 +1337,13 @@ def get_job(self, job_name: str) -> JobSpec:

     def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec:
         dst = os.path.join(self.active_job_dir, job_name)
-        if not tf_io.gfile.exists(dst):
+        if not exists(dst):
             raise ValueError(f"Unable to locate jobspec {dst}")

         with tempfile.NamedTemporaryFile("w") as f:
             serialize_jobspec(job_spec, f)
             # Upload the job for bastion to pickup.
-            tf_io.gfile.copy(f.name, dst, overwrite=True)
+            copy(f.name, dst, overwrite=True)
             logging.info("Job %s is updating.", job_name)

         return job_spec
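Taken together, the client-side methods manage jobs purely through files under `active_job_dir`: submit uploads a jobspec, update overwrites it, and cancel flips the job state and waits for the bastion to remove the jobspec. An end-to-end sketch, where `client` is a hypothetical instance of this class:

```python
client.submit_job("my-job", job_spec_file="/tmp/my-job.json")  # upload jobspec
spec = client.get_job("my-job")              # download and deserialize it
client.update_job("my-job", job_spec=spec)   # re-serialize and overwrite
client.cancel_job("my-job")                  # set state, poll until jobspec is gone
```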