Skip to content

Commit

Permalink
Added DagshubCallback (#21404)
Browse files Browse the repository at this point in the history
* integrated logger

* bugfix

* added data

* bugfix

* model + state artifacts should log

* fixed paths

* i lied, trying again

* updated function call

* typo

this is painful :( what a stupid error

* typo

this is painful :( what a stupid error

* pivoted to adding a directory

* silly path bug

* multiple experiments

* migrated to getattr

* syntax fix

* syntax fix

* fixed repo pointer

* fixed path error

* added dataset if dataloader is present, uploaded artifacts

* variable in scope

* removed unnecessary line

* updated error type

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* trimmed unused variables, imports

* style formatting

* removed type conversion reliance

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* reverted accidental line deletion

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
jinensetpal and sgugger authored Feb 1, 2023
1 parent 8d58077 commit 3fadb4b
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None


def is_dagshub_available():
    """Return `True` when both `dagshub` and `mlflow` can be imported (DagsHub logging is built on mlflow)."""
    required_packages = ("dagshub", "mlflow")
    return all(importlib.util.find_spec(pkg) is not None for pkg in required_packages)


def is_fairscale_available():
    """Return `True` when the `fairscale` package is importable in the current environment."""
    spec = importlib.util.find_spec("fairscale")
    return spec is not None

Expand Down Expand Up @@ -522,6 +526,8 @@ def get_available_reporting_integrations():
integrations.append("azure_ml")
if is_comet_available():
integrations.append("comet_ml")
if is_dagshub_available():
integrations.append("dagshub")
if is_mlflow_available():
integrations.append("mlflow")
if is_neptune_available():
Expand Down Expand Up @@ -1045,6 +1051,55 @@ def __del__(self):
self._ml_flow.end_run()


class DagsHubCallback(MLflowCallback):
    """
    A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/). Extends [`MLflowCallback`], so all
    mlflow-based metric logging is inherited; this class additionally uploads artifacts to the DagsHub repo.
    """

    def __init__(self):
        super().__init__()
        if not is_dagshub_available():
            raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")

        # Imported lazily so that merely importing this module does not require dagshub.
        from dagshub.upload import Repo

        self.Repo = Repo

    def setup(self, *args, **kwargs):
        """
        Setup the DagsHub's Logging integration.

        Environment:
            HF_DAGSHUB_LOG_ARTIFACTS (`str`, *optional*):
                Whether to save the data and model artifacts for the experiment. Defaults to `False`.
        """

        self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
        self.remote = os.getenv("MLFLOW_TRACKING_URI")

        # Guard BEFORE using `self.remote`: the original code called `self.remote.split(...)` first,
        # so an unset MLFLOW_TRACKING_URI raised an opaque AttributeError instead of this message.
        if self.remote is None:
            raise RuntimeError(
                "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
                " `dagshub.init()`?"
            )

        # The tracking URI is a URL; split on "/" rather than `os.sep`, which is "\\" on Windows
        # and would break the owner/repo-name extraction there.
        uri_parts = self.remote.split("/")
        self.repo = self.Repo(
            owner=uri_parts[-2],
            # Strip a trailing ".git"-style suffix from the last path segment.
            name=uri_parts[-1].split(".")[0],
            branch=os.getenv("BRANCH") or "main",
        )
        self.path = Path("artifacts")

        super().setup(*args, **kwargs)

    def on_train_end(self, args, state, control, **kwargs):
        # Optionally upload run artifacts: the serialized training dataset (when a dataloader was
        # attached to the callback) plus everything Trainer wrote to `output_dir`.
        if self.log_artifacts:
            if getattr(self, "train_dataloader", None):
                torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))

            self.repo.directory(str(self.path)).add_dir(args.output_dir)


class NeptuneMissingConfiguration(Exception):
def __init__(self):
super().__init__(
Expand Down Expand Up @@ -1465,6 +1520,7 @@ def on_save(self, args, state, control, **kwargs):
"wandb": WandbCallback,
"codecarbon": CodeCarbonCallback,
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
}


Expand Down

0 comments on commit 3fadb4b

Please sign in to comment.