remove SharedDDP as it is deprecated (huggingface#25702)
* remove SharedDDP as it was deprecated

* apply review suggestion

* make style

* Oops, forgot to remove the compute_loss context manager in Seq2SeqTrainer.

* remove the unnecessary conditional statement

* keep the logic of IPEX

* clean code

* mixed precision setup & make fixup

---------

Co-authored-by: statelesshz <jihuazhong1@huawei.com>
statelesshz authored Oct 6, 2023
1 parent e840aa6 commit 27597fe
Showing 11 changed files with 39 additions and 299 deletions.
14 changes: 1 addition & 13 deletions examples/legacy/seq2seq/seq2seq_trainer.py
@@ -19,7 +19,6 @@
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import PreTrainedModel, Trainer, logging
-from transformers.integrations import is_fairscale_available
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
Adafactor,
@@ -36,10 +35,6 @@
from transformers.utils import is_torch_tpu_available


-if is_fairscale_available():
-    from fairscale.optim import OSS


logger = logging.get_logger(__name__)

arg_to_scheduler = {
@@ -118,14 +113,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
-if self.sharded_ddp:
-    self.optimizer = OSS(
-        params=optimizer_grouped_parameters,
-        optim=optimizer_cls,
-        **optimizer_kwargs,
-    )
-else:
-    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

if self.lr_scheduler is None:
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
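
For context: the deleted branch wrapped the chosen optimizer class in fairscale's OSS (optimizer state sharding), while the kept line instantiates the class directly. A minimal, self-contained sketch of the remaining pattern follows; the build_optimizer helper, its arguments, and the toy model are hypothetical and only illustrate how optimizer_cls / optimizer_kwargs are used in the hunk above.

import torch
from torch.optim import AdamW

from transformers.optimization import Adafactor


def build_optimizer(model, use_adafactor=False, lr=5e-5, eps=1e-8):
    # The real trainer builds parameter groups with/without weight decay;
    # a flat list of trainable parameters keeps this sketch short.
    params = [p for p in model.parameters() if p.requires_grad]
    if use_adafactor:
        optimizer_cls = Adafactor
        # Adafactor accepts an explicit lr only when relative_step is disabled.
        optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
    else:
        optimizer_cls = AdamW
        optimizer_kwargs = {"eps": eps}
    optimizer_kwargs["lr"] = lr
    # No fairscale OSS wrapper any more: instantiate the optimizer class directly.
    return optimizer_cls(params, **optimizer_kwargs)


optimizer = build_optimizer(torch.nn.Linear(4, 2))
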
2 changes: 0 additions & 2 deletions setup.py
@@ -109,7 +109,6 @@
"diffusers",
"dill<0.3.5",
"evaluate>=0.2.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
"filelock",
@@ -275,7 +274,6 @@ def run(self):

extras["sagemaker"] = deps_list("sagemaker")
extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
extras["fairscale"] = deps_list("fairscale")
extras["optuna"] = deps_list("optuna")
extras["ray"] = deps_list("ray[tune]")
extras["sigopt"] = deps_list("sigopt")
1 change: 0 additions & 1 deletion src/transformers/dependency_versions_table.py
@@ -16,7 +16,6 @@
"diffusers": "diffusers",
"dill": "dill<0.3.5",
"evaluate": "evaluate>=0.2.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
"filelock": "filelock",
2 changes: 0 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -57,7 +57,6 @@
"is_codecarbon_available",
"is_comet_available",
"is_dagshub_available",
"is_fairscale_available",
"is_flyte_deck_standard_available",
"is_flytekit_available",
"is_mlflow_available",
@@ -118,7 +117,6 @@
is_codecarbon_available,
is_comet_available,
is_dagshub_available,
-is_fairscale_available,
is_flyte_deck_standard_available,
is_flytekit_available,
is_mlflow_available,
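
The symbol has to be deleted twice in this file because integrations/__init__.py keeps both a lazy _import_structure (the real module delegates attribute lookups to a _LazyModule helper) and a mirrored block of concrete imports for type checkers. A simplified, hypothetical sketch of that two-part layout using a module-level __getattr__; names and entries here are illustrative, not the project's actual code.

import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "integration_utils": ["is_comet_available", "is_mlflow_available"],
}

if TYPE_CHECKING:
    # Static type checkers see ordinary imports.
    from .integration_utils import is_comet_available, is_mlflow_available  # noqa: F401
else:
    def __getattr__(name):
        # Resolve the attribute lazily, importing the submodule on first access.
        for submodule, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(f".{submodule}", __name__), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
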
4 changes: 0 additions & 4 deletions src/transformers/integrations/integration_utils.py
@@ -134,10 +134,6 @@ def is_dagshub_available():
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]


-def is_fairscale_available():
-    return importlib.util.find_spec("fairscale") is not None


def is_neptune_available():
return _has_neptune
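
The deleted helper followed the same pattern as the surrounding is_*_available functions: probe for the package with importlib.util.find_spec without actually importing it. A standalone sketch of that pattern; _is_package_available is a hypothetical name used only for illustration.

import importlib.util


def _is_package_available(*names):
    # find_spec returns None when a distribution is not installed, and it does
    # so without importing (and thus without paying the import cost of) the package.
    return all(importlib.util.find_spec(name) is not None for name in names)


print(_is_package_available("dagshub", "mlflow"))  # mirrors is_dagshub_available above
print(_is_package_available("fairscale"))          # the check this commit removes
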

8 changes: 0 additions & 8 deletions src/transformers/testing_utils.py
@@ -42,7 +42,6 @@

from .integrations import (
is_clearml_available,
-is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
@@ -871,13 +870,6 @@ def require_deepspeed(test_case):
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)


-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)


def require_apex(test_case):
"""
Decorator marking a test that requires apex
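
The require_* helpers deleted here all follow the same shape as require_deepspeed above: wrap unittest.skipUnless around an availability probe so the test is skipped when the optional dependency is missing. A minimal, self-contained sketch of the pattern; the decorator name and the example test are hypothetical.

import importlib.util
import unittest


def require_optional_package(package_name, test_case=None):
    """Skip a test unless `package_name` is importable (hypothetical helper)."""
    available = importlib.util.find_spec(package_name) is not None
    decorator = unittest.skipUnless(available, f"test requires {package_name}")
    return decorator if test_case is None else decorator(test_case)


class ExampleTest(unittest.TestCase):
    @require_optional_package("fairscale")
    def test_needs_fairscale(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
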