Skip to content

Commit

Permalink
[train] Fix wandb/comet integration API calls (ray-project#38978)
Browse files Browse the repository at this point in the history
Removes remaining calls to checkpoint.dir_or_data in the wandb/comet integrations

Signed-off-by: Kai Fricke <kai@anyscale.com>
  • Loading branch information
krfricke authored Aug 31, 2023
1 parent 9350b16 commit b248d36
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
38 changes: 26 additions & 12 deletions python/ray/air/integrations/comet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
from typing import Dict, List

import pyarrow.fs

from ray.train import _use_storage_context
from ray.tune.logger import LoggerCallback
from ray.tune.experiment import Trial
from ray.tune.utils import flatten_dict
Expand Down Expand Up @@ -223,19 +226,30 @@ def log_trial_save(self, trial: "Trial"):
name=f"checkpoint_{(str(trial))}", artifact_type="model"
)

checkpoint_root = None

if _use_storage_context():
if isinstance(trial.checkpoint.filesystem, pyarrow.fs.LocalFileSystem):
checkpoint_root = trial.checkpoint.path
# Todo: For other filesystems, we may want to use
# artifact.add_remote() instead. However, this requires a full
# URI. We can add this once we have a way to retrieve it.
else:
checkpoint_root = trial.checkpoint.dir_or_data

# Walk through checkpoint directory and add all files to artifact
checkpoint_root = trial.checkpoint.dir_or_data
for root, dirs, files in os.walk(checkpoint_root):
rel_root = os.path.relpath(root, checkpoint_root)
for file in files:
local_file = os.path.join(checkpoint_root, rel_root, file)
logical_path = os.path.join(rel_root, file)

# Strip leading `./`
if logical_path.startswith("./"):
logical_path = logical_path[2:]

artifact.add(local_file, logical_path=logical_path)
if checkpoint_root:
for root, dirs, files in os.walk(checkpoint_root):
rel_root = os.path.relpath(root, checkpoint_root)
for file in files:
local_file = os.path.join(checkpoint_root, rel_root, file)
logical_path = os.path.join(rel_root, file)

# Strip leading `./`
if logical_path.startswith("./"):
logical_path = logical_path[2:]

artifact.add(local_file, logical_path=logical_path)

experiment.log_artifact(artifact)

Expand Down
15 changes: 12 additions & 3 deletions python/ray/air/integrations/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
from numbers import Number

import pyarrow.fs

from types import ModuleType
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

Expand All @@ -16,6 +18,7 @@
from ray.air._internal import usage as air_usage
from ray.air.util.node import _force_on_current_node

from ray.train import _use_storage_context
from ray.tune.logger import LoggerCallback
from ray.tune.utils import flatten_dict
from ray.tune.experiment import Trial
Expand Down Expand Up @@ -671,9 +674,15 @@ def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):

def log_trial_save(self, trial: "Trial"):
if self.upload_checkpoints and trial.checkpoint:
self._trial_queues[trial].put(
(_QueueItem.CHECKPOINT, trial.checkpoint.dir_or_data)
)
checkpoint_root = None
if _use_storage_context():
if isinstance(trial.checkpoint.filesystem, pyarrow.fs.LocalFileSystem):
checkpoint_root = trial.checkpoint.path
else:
checkpoint_root = trial.checkpoint.dir_or_data

if checkpoint_root:
self._trial_queues[trial].put((_QueueItem.CHECKPOINT, checkpoint_root))

def log_trial_end(self, trial: "Trial", failed: bool = False):
self._signal_logging_actor_stop(trial=trial)
Expand Down

0 comments on commit b248d36

Please sign in to comment.