-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tf2estimator on Pyspark support tensorboard #3959
Changes from all commits
099c739
6deaad6
6603cba
d1e8574
2fe944f
a8678eb
15d9999
1977311
4ff9f5f
e3a6e4c
ed250f5
c4c67df
ed19ce8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,8 @@ | |
from bigdl.orca.data.utils import ray_partition_get_data_label | ||
from bigdl.orca.data.file import put_local_dir_to_remote | ||
from bigdl.orca.learn.utils import save_pkl, duplicate_stdout_stderr_to_file,\ | ||
get_specific_object_from_callbacks, get_replaced_path, get_rank | ||
get_specific_object_from_callbacks, get_replaced_path, get_rank, \ | ||
process_tensorboard_in_callbacks | ||
from bigdl.orca.learn.log_monitor import LogMonitor | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -274,7 +275,6 @@ def distributed_train_func(self, data_creator, config, epochs=1, verbose=1, | |
config=config, epochs=epochs, | ||
steps_per_epoch=steps_per_epoch, | ||
validation_steps=validation_steps) | ||
checkpoint = None | ||
if callbacks: | ||
checkpoint = get_specific_object_from_callbacks(tf.keras.callbacks.ModelCheckpoint, | ||
callbacks) | ||
|
@@ -283,6 +283,8 @@ def distributed_train_func(self, data_creator, config, epochs=1, verbose=1, | |
replaced_checkpoint_path = get_replaced_path(checkpoint.filepath) | ||
checkpoint.filepath = replaced_checkpoint_path | ||
|
||
replaced_log_dir = process_tensorboard_in_callbacks(callbacks, "fit", self.rank) | ||
|
||
history = model.fit(train_dataset, | ||
epochs=epochs, | ||
verbose=verbose, | ||
|
@@ -294,14 +296,22 @@ def distributed_train_func(self, data_creator, config, epochs=1, verbose=1, | |
validation_steps=validation_steps, | ||
validation_freq=validation_freq) | ||
|
||
if checkpoint: | ||
try: | ||
if self.rank == 0: | ||
put_local_dir_to_remote(os.path.dirname(replaced_checkpoint_path), | ||
original_checkpoint_dir) | ||
finally: | ||
shutil.rmtree(os.path.dirname(replaced_checkpoint_path)) | ||
|
||
if callbacks: | ||
if checkpoint: | ||
checkpoint_copied = False | ||
try: | ||
if self.rank == 0: | ||
put_local_dir_to_remote(os.path.dirname(replaced_checkpoint_path), | ||
original_checkpoint_dir) | ||
checkpoint_copied = True | ||
except Exception: | ||
logger.warning("Error when copy local checkpoint {} to {}, " | ||
"please get the local checkpoint manually" | ||
.format(replaced_checkpoint_path, original_checkpoint_dir)) | ||
if checkpoint_copied: | ||
shutil.rmtree(os.path.dirname(replaced_checkpoint_path)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about printing a warning stating that there is an error and that the checkpoint is located at xxx, so that users will have a chance to get it manually? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok |
||
if replaced_log_dir and os.path.exists(replaced_log_dir): | ||
shutil.rmtree(replaced_log_dir) | ||
return (model, history) | ||
|
||
def step(self, data_creator, epochs=1, batch_size=32, verbose=1, | ||
|
@@ -374,6 +384,8 @@ def validate(self, data_creator, batch_size=32, verbose=1, sample_weight=None, | |
dataset = dataset_handler.handle_dataset_validation(data_creator, | ||
config=config, | ||
steps=steps) | ||
if callbacks: | ||
replaced_log_dir = process_tensorboard_in_callbacks(callbacks, "evaluate", self.rank) | ||
|
||
params = dict( | ||
verbose=verbose, | ||
|
@@ -397,6 +409,11 @@ def validate(self, data_creator, batch_size=32, verbose=1, sample_weight=None, | |
else: | ||
stats = {"results": results} | ||
|
||
# clean temporary dir for tensorboard | ||
if callbacks: | ||
if replaced_log_dir and os.path.exists(replaced_log_dir): | ||
shutil.rmtree(replaced_log_dir) | ||
|
||
if self.rank == 0: | ||
if self.need_to_log_to_driver: | ||
LogMonitor.stop_log_monitor(self.log_path, self.logger_thread, self.thread_stop) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to just use pyarrow or hadoop command instead of interleaving them together?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using pyarrow to copy an HDFS tree is a little complex; will change to use the hadoop command in all cases.