Skip to content

Commit

Permalink
[ray] Fix datasets_modules ImportError with Ray Tune (huggingface#1…
Browse files Browse the repository at this point in the history
…2749)

* Fix dynamic_modules ImportError with Ray Tune

* Nit
  • Loading branch information
Yard1 authored Jul 19, 2021
1 parent 534f6eb commit cab3b86
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
"""
Integrations with other Python libraries.
"""
import functools
import importlib.util
import numbers
import os
import sys
import tempfile
from pathlib import Path

from .file_utils import is_datasets_available
from .utils import logging


Expand Down Expand Up @@ -246,8 +249,34 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
"Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
)

trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

@functools.wraps(trainable)
def dynamic_modules_import_trainable(*args, **kwargs):
"""
Wrapper around ``tune.with_parameters`` to ensure datasets_modules are loaded on each Actor.
Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.
Assumes that ``_objective``, defined above, is a function.
"""
if is_datasets_available():
import datasets.load

dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
# load dynamic_modules from path
spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
datasets_modules = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = datasets_modules
spec.loader.exec_module(datasets_modules)
return trainable(*args, **kwargs)

# special attr set by tune.with_parameters
if hasattr(trainable, "__mixins__"):
dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

analysis = ray.tune.run(
ray.tune.with_parameters(_objective, local_trainer=trainer),
dynamic_modules_import_trainable,
config=trainer.hp_space(None),
num_samples=n_trials,
**kwargs,
Expand Down

0 comments on commit cab3b86

Please sign in to comment.