diff --git a/examples/legacy/seq2seq/seq2seq_trainer.py b/examples/legacy/seq2seq/seq2seq_trainer.py
index dbf12725f2db07..6b52d338af402f 100644
--- a/examples/legacy/seq2seq/seq2seq_trainer.py
+++ b/examples/legacy/seq2seq/seq2seq_trainer.py
@@ -19,7 +19,6 @@
 from torch.utils.data import DistributedSampler, RandomSampler
 
 from transformers import PreTrainedModel, Trainer, logging
-from transformers.integrations import is_fairscale_available
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
     Adafactor,
@@ -36,10 +35,6 @@
 from transformers.utils import is_torch_tpu_available
 
 
-if is_fairscale_available():
-    from fairscale.optim import OSS
-
-
 logger = logging.get_logger(__name__)
 
 arg_to_scheduler = {
@@ -118,14 +113,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
                     "eps": self.args.adam_epsilon,
                 }
             optimizer_kwargs["lr"] = self.args.learning_rate
-            if self.sharded_ddp:
-                self.optimizer = OSS(
-                    params=optimizer_grouped_parameters,
-                    optim=optimizer_cls,
-                    **optimizer_kwargs,
-                )
-            else:
-                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
 
         if self.lr_scheduler is None:
             self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
diff --git a/setup.py b/setup.py
index 866ddb37b633dc..444efbf268d1b6 100644
--- a/setup.py
+++ b/setup.py
@@ -109,7 +109,6 @@
     "diffusers",
     "dill<0.3.5",
     "evaluate>=0.2.0",
-    "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
     "filelock",
@@ -275,7 +274,6 @@ def run(self):
 extras["sagemaker"] = deps_list("sagemaker")
 extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
-extras["fairscale"] = deps_list("fairscale")
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
 extras["sigopt"] = deps_list("sigopt")
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 391612ed3c26e7..20dacb3cf0d4e1 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -16,7 +16,6 @@
     "diffusers": "diffusers",
    "dill": "dill<0.3.5",
     "evaluate": "evaluate>=0.2.0",
-    "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
     "filelock": "filelock",
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index bf24514e718134..ddd36955b3bf36 100644
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -57,7 +57,6 @@
         "is_codecarbon_available",
         "is_comet_available",
         "is_dagshub_available",
-        "is_fairscale_available",
         "is_flyte_deck_standard_available",
         "is_flytekit_available",
         "is_mlflow_available",
@@ -118,7 +117,6 @@
         is_codecarbon_available,
         is_comet_available,
         is_dagshub_available,
-        is_fairscale_available,
         is_flyte_deck_standard_available,
         is_flytekit_available,
         is_mlflow_available,
diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index ec4d177ce5a743..10f86ee4198032 100644
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -134,10 +134,6 @@ def is_dagshub_available():
     return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
 
 
-def is_fairscale_available():
-    return importlib.util.find_spec("fairscale") is not None
-
-
 def is_neptune_available():
     return _has_neptune
 
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index adbf7fe32a4b80..341e6cd1688f03 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -42,7 +42,6 @@
 
 from .integrations import (
     is_clearml_available,
-    is_fairscale_available,
     is_optuna_available,
     is_ray_available,
     is_sigopt_available,
@@ -871,13 +870,6 @@ def require_deepspeed(test_case):
     return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
 
 
-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
-
-
 def require_apex(test_case):
     """
     Decorator marking a test that requires apex
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 416b10cec5cb8f..9fce06968edc52 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -40,7 +40,6 @@
 from .integrations import (
     get_reporting_integration_callbacks,
     hp_params,
-    is_fairscale_available,
 )
 
 # isort: on
@@ -58,7 +57,6 @@
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .modelcard import TrainingSummary
@@ -107,7 +105,6 @@
     IntervalStrategy,
     PredictionOutput,
     RemoveColumnsCollator,
-    ShardedDDPOption,
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
@@ -171,15 +168,6 @@
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
 
-if is_fairscale_available():
-    dep_version_check("fairscale")
-    import fairscale
-    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
-    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
-    from fairscale.nn.wrap import auto_wrap
-    from fairscale.optim import OSS
-    from fairscale.optim.grad_scaler import ShardedGradScaler
-
 
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -420,33 +408,6 @@ def __init__(
                     " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                 )
 
-        # Setup Sharded DDP training
-        self.sharded_ddp = None
-        if len(args.sharded_ddp) > 0:
-            if self.is_deepspeed_enabled:
-                raise ValueError(
-                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
-                )
-            if len(args.fsdp) > 0:
-                raise ValueError(
-                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
-                )
-            if args.parallel_mode != ParallelMode.DISTRIBUTED:
-                raise ValueError("Using sharded DDP only works in distributed training.")
-            elif not is_fairscale_available():
-                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
-            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
-                raise ImportError(
-                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
-                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
-                )
-            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
-                self.sharded_ddp = ShardedDDPOption.SIMPLE
-            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
-                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
-            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
-                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
-
         self.fsdp = None
         if len(args.fsdp) > 0:
             if self.is_deepspeed_enabled:
@@ -488,14 +449,12 @@ def __init__(
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         # and we only use deepspeed for training at the moment
         # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
-        # 4. Sharded DDP - same as MP
-        # 5. FSDP - same as MP
+        # 4. FSDP - same as MP
         self.place_model_on_device = args.place_model_on_device
         if (
             self.is_model_parallel
             or self.is_deepspeed_enabled
             or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
-            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
             or (self.fsdp is not None)
             or self.is_fsdp_enabled
         ):
@@ -545,11 +504,11 @@ def __init__(
                     " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                     " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                 )
-        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
+        if (self.is_deepspeed_enabled or (self.fsdp is not None)) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
             raise RuntimeError(
-                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
+                "Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled."
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
@@ -592,7 +551,6 @@ def __init__(
 
         # Mixed precision setup
         self.use_apex = False
-        self.use_cuda_amp = False
         self.use_cpu_amp = False
 
         # Mixed precision setup for SageMaker Model Parallel
@@ -617,33 +575,19 @@ def __init__(
                         f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                         "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                     )
-
-        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
-            if args.half_precision_backend == "auto":
-                if args.device == torch.device("cpu"):
-                    if args.fp16:
-                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
-                    else:
-                        args.half_precision_backend = "cpu_amp"
+        if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
+            if args.device == torch.device("cpu"):
+                if args.fp16:
+                    raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                 else:
-                    args.half_precision_backend = "cuda_amp"
-
+                    args.half_precision_backend = "cpu_amp"
         logger.info(f"Using {args.half_precision_backend} half precision backend")
 
-        self.do_grad_scaling = False
         if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
-            if self.sharded_ddp is not None:
-                if args.half_precision_backend == "cuda_amp":
-                    self.use_cuda_amp = True
-                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
-                    # bf16 does not need grad scaling
-                    self.do_grad_scaling = self.amp_dtype == torch.float16
-                    if self.do_grad_scaling:
-                        self.scaler = ShardedGradScaler()
-                elif args.half_precision_backend == "cpu_amp":
-                    self.use_cpu_amp = True
-                    self.amp_dtype = torch.bfloat16
+            if args.half_precision_backend == "cpu_amp":
+                self.use_cpu_amp = True
+                self.amp_dtype = torch.bfloat16
             elif args.half_precision_backend == "apex":
                 if not is_apex_available():
                     raise ImportError(
@@ -652,18 +596,6 @@ def __init__(
                     )
                 self.use_apex = True
 
-        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
-        if (
-            is_sagemaker_mp_enabled()
-            and self.use_cuda_amp
-            and args.max_grad_norm is not None
-            and args.max_grad_norm > 0
-        ):
-            raise ValueError(
-                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
-                "along 'max_grad_norm': 0 in your hyperparameters."
-            )
-
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
@@ -994,27 +926,20 @@ def create_optimizer(self):
 
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
 
-            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                self.optimizer = OSS(
-                    params=optimizer_grouped_parameters,
-                    optim=optimizer_cls,
-                    **optimizer_kwargs,
-                )
-            else:
-                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-                if optimizer_cls.__name__ == "Adam8bit":
-                    import bitsandbytes
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            if optimizer_cls.__name__ == "Adam8bit":
+                import bitsandbytes
 
-                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
 
-                    skipped = 0
-                    for module in opt_model.modules():
-                        if isinstance(module, nn.Embedding):
-                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
-                            logger.info(f"skipped {module}: {skipped/2**20}M params")
-                            manager.register_module_override(module, "weight", {"optim_bits": 32})
-                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
-                    logger.info(f"skipped: {skipped/2**20}M params")
+                skipped = 0
+                for module in opt_model.modules():
+                    if isinstance(module, nn.Embedding):
+                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                        logger.info(f"skipped {module}: {skipped/2**20}M params")
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                logger.info(f"skipped: {skipped/2**20}M params")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -1333,7 +1258,6 @@ def torch_jit_model_eval(self, model, dataloader, training=False):
                     jit_model(**example_batch)
                 model = jit_model
                 self.use_cpu_amp = False
-                self.use_cuda_amp = False
             except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
                 logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
 
@@ -1396,25 +1320,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
             return model
 
         # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_ddp is not None:
-            # Sharded DDP!
-            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                model = ShardedDDP(model, self.optimizer)
-            else:
-                mixed_precision = self.args.fp16 or self.args.bf16
-                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
-                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
-                # XXX: Breaking the self.model convention but I see no way around it for now.
-                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
-                    model = auto_wrap(model)
-                self.model = model = FullyShardedDDP(
-                    model,
-                    mixed_precision=mixed_precision,
-                    reshard_after_forward=zero_3,
-                    cpu_offload=cpu_offload,
-                ).to(self.args.device)
         # Distributed training using PyTorch FSDP
-        elif self.fsdp is not None and self.args.fsdp_config["xla"]:
+        if self.fsdp is not None and self.args.fsdp_config["xla"]:
             try:
                 from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                 from torch_xla.distributed.fsdp import checkpoint_module
@@ -1669,13 +1576,7 @@ def _inner_training_loop(
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = (
-            self.sharded_ddp is not None
-            and self.sharded_ddp != ShardedDDPOption.SIMPLE
-            or is_sagemaker_mp_enabled()
-            or self.fsdp is not None
-            or self.is_fsdp_enabled
-        )
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -1716,7 +1617,7 @@ def _inner_training_loop(
 
         # as the model is wrapped, don't use `accelerator.prepare`
         # this is for unhandled cases such as
-        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False
 
         if delay_optimizer_creation:
@@ -1932,14 +1833,6 @@ def _inner_training_loop(
                     if args.max_grad_norm is not None and args.max_grad_norm > 0:
                         # deepspeed does its own clipping
 
-                        if self.do_grad_scaling:
-                            # Reduce gradients first for XLA
-                            if is_torch_tpu_available():
-                                gradients = xm._fetch_gradients(self.optimizer)
-                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
-                            # AMP: gradients need unscaling
-                            self.scaler.unscale_(self.optimizer)
-
                         if is_sagemaker_mp_enabled() and args.fp16:
                             self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif hasattr(self.optimizer, "clip_grad_norm"):
@@ -1961,24 +1854,8 @@ def _inner_training_loop(
                         )
 
                     # Optimizer step
-                    optimizer_was_run = True
-                    if is_torch_tpu_available():
-                        if self.do_grad_scaling:
-                            self.scaler.step(self.optimizer)
-                            self.scaler.update()
-                        else:
-                            # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
-                            self.optimizer.step()
-                    elif self.do_grad_scaling:
-                        scale_before = self.scaler.get_scale()
-                        self.scaler.step(self.optimizer)
-                        self.scaler.update()
-                        scale_after = self.scaler.get_scale()
-                        optimizer_was_run = scale_before <= scale_after
-                    else:
-                        self.optimizer.step()
-                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
-
+                    self.optimizer.step()
+                    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                     if optimizer_was_run:
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -2408,9 +2285,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
             self.model_wrapped.save_checkpoint(output_dir)
 
         # Save optimizer and scheduler
-        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-            self.optimizer.consolidate_state_dict()
-
         if self.fsdp or self.is_fsdp_enabled:
             if self.is_fsdp_enabled:
                 save_fsdp_optimizer(
@@ -2455,8 +2329,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
-            if self.do_grad_scaling:
-                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2600,8 +2472,6 @@ def opt_load_hook(mod, opt):
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             reissue_pt_warnings(caught_warnings)
-            if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
 
     def hyperparameter_search(
         self,
@@ -2744,12 +2614,8 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
         A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
         arguments, depending on the situation.
         """
-        if self.use_cuda_amp or self.use_cpu_amp:
-            ctx_manager = (
-                torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
-                if self.use_cpu_amp
-                else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
-            )
+        if self.use_cpu_amp:
+            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
         else:
             ctx_manager = contextlib.nullcontext()
 
@@ -2786,9 +2652,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
 
-        if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
-        elif self.use_apex:
+        if self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
@@ -2872,12 +2736,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             if IS_SAGEMAKER_MP_POST_1_10:
                 # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                 Path(os.path.join(output_dir, "user_content.pt")).touch()
-        elif (
-            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
-            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
-            or self.fsdp is not None
-            or self.is_fsdp_enabled
-        ):
+        elif self.fsdp is not None or self.is_fsdp_enabled:
             state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 29643cc896916e..aaff31a2dc9e29 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -266,7 +266,6 @@ def prediction_step(
         has_labels = "labels" in inputs
         inputs = self._prepare_inputs(inputs)
 
-        # XXX: adapt synced_gpus for fairscale as well
         # Priority (handled in generate):
         # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
         if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index e19bc854ff48d8..5bf29efffa8fc6 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -651,14 +651,6 @@ def number_of_arguments(func):
     return len(inspect.signature(func).parameters)
 
 
-class ShardedDDPOption(ExplicitEnum):
-    SIMPLE = "simple"
-    ZERO_DP_2 = "zero_dp_2"
-    ZERO_DP_3 = "zero_dp_3"
-    OFFLOAD = "offload"
-    AUTO_WRAP = "auto_wrap"
-
-
 def find_executable_batch_size(
     function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
 ):
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 7c016b15b2e648..635ab656ff699c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -34,7 +34,6 @@
     HubStrategy,
     IntervalStrategy,
     SchedulerType,
-    ShardedDDPOption,
 )
 from .utils import (
     ExplicitEnum,
@@ -328,9 +327,9 @@ class TrainingArguments:
         fp16_backend (`str`, *optional*, defaults to `"auto"`):
             This argument is deprecated. Use `half_precision_backend` instead.
         half_precision_backend (`str`, *optional*, defaults to `"auto"`):
-            The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
-            `"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
-            will force the requested backend.
+            The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
+            use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
+            requested backend.
         bf16_full_eval (`bool`, *optional*, defaults to `False`):
             Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can
             harm metric values. This is an experimental API and it may change.
@@ -410,21 +409,6 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
             can take a long time) but will not yield the same results as the interrupted training would have.
-        sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`):
-            Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed
-            training only). This is an experimental feature.
-
-            A list of options along the following:
-
-            - `"simple"`: to use first instance of sharded DDP released by fairscale (`ShardedDDP`) similar to ZeRO-2.
-            - `"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
-              Zero-2 mode (with `reshard_after_forward=False`).
-            - `"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
-              Zero-3 mode (with `reshard_after_forward=True`).
-            - `"offload"`: to add ZeRO-offload (only compatible with `"zero_dp_2"` and `"zero_dp_3"`).
-
-            If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
-            list for `False` and `["simple"]` for `True`.
         fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
             Use PyTorch Distributed Parallel Training (in distributed training only).
 
@@ -877,7 +861,7 @@ class TrainingArguments:
         default="auto",
         metadata={
             "help": "The backend to be used for half precision.",
-            "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+            "choices": ["auto", "apex", "cpu_amp"],
         },
     )
     bf16_full_eval: bool = field(
@@ -996,17 +980,6 @@ class TrainingArguments:
             )
         },
     )
-    sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field(
-        default="",
-        metadata={
-            "help": (
-                "Whether or not to use sharded DDP training (in distributed training only). The base option should be"
-                " `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
-                " this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
-                " with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
-            ),
-        },
-    )
     fsdp: Optional[Union[List[FSDPOption], str]] = field(
         default="",
         metadata={
@@ -1154,7 +1127,7 @@ class TrainingArguments:
         default="auto",
         metadata={
             "help": "Deprecated. Use half_precision_backend instead",
-            "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+            "choices": ["auto", "apex", "cpu_amp"],
         },
     )
     push_to_hub_model_id: Optional[str] = field(
@@ -1407,8 +1380,6 @@ def __post_init__(self):
                     " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
                     " `--half_precision_backend cuda_amp` instead"
                 )
-            if not (self.sharded_ddp == "" or not self.sharded_ddp):
-                raise ValueError("sharded_ddp is not supported with bf16")
 
         if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
             if self.evaluation_strategy == IntervalStrategy.NO:
@@ -1508,7 +1479,7 @@ def __post_init__(self):
             # no need to assert on else
 
         # if training args is specified, it will override the one specified in the accelerate config
-        if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
+        if self.half_precision_backend != "apex":
             mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
             if self.fp16:
                 mixed_precision_dtype = "fp16"
@@ -1541,26 +1512,6 @@ def __post_init__(self):
                 " during training"
             )
 
-        if not (self.sharded_ddp == "" or not self.sharded_ddp):
-            warnings.warn(
-                "using `sharded_ddp` is deprecated and will be removed in version 4.33"
-                " of 🤗 Transformers. Use `fsdp` instead",
-                FutureWarning,
-            )
-            if isinstance(self.sharded_ddp, bool):
-                self.sharded_ddp = "simple" if self.sharded_ddp else ""
-            if isinstance(self.sharded_ddp, str):
-                self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
-            if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
-                raise ValueError(
-                    "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
-                    '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
-                )
-            elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
-                raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
-            elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
-                raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
-
         if isinstance(self.fsdp, bool):
             self.fsdp = "full_shard" if self.fsdp else ""
         if isinstance(self.fsdp, str):
diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py
index 7fd2fc9389ab25..831ffd5feedeb1 100644
--- a/tests/extended/test_trainer_ext.py
+++ b/tests/extended/test_trainer_ext.py
@@ -16,7 +16,6 @@
 import os
 import re
 import sys
-import unittest
 from pathlib import Path
 from typing import Tuple
 from unittest.mock import patch
@@ -32,7 +31,6 @@
     get_torch_dist_unique_port,
     require_apex,
     require_bitsandbytes,
-    require_fairscale,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -105,36 +103,6 @@ def test_run_seq2seq_dp(self):
     def test_run_seq2seq_ddp(self):
         self.run_seq2seq_quick(distributed=True)
 
-    # test --sharded_ddp w/o --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_sharded_ddp(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
-
-    # test --sharded_ddp w/ --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_sharded_ddp_fp16(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
-
-    # test --sharded_ddp zero_dp_2 w/o --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_fully_sharded_ddp(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
-
-    # test --sharded_ddp zero_dp_2 w/ --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
-        self.run_seq2seq_quick(
-            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
-        )
-
     @require_apex
     @require_torch_gpu
     def test_run_seq2seq_apex(self):
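
Migration note (not part of the patch): a minimal sketch of how a job that used the removed `--sharded_ddp "zero_dp_3 auto_wrap"` option can move to the `fsdp` argument that this diff keeps. The checkpoint name and the toy dataset are placeholders, and treating `full_shard auto_wrap` as the counterpart of the old ZeRO-3 mode is an approximate mapping, not an exact equivalence.

```python
# Hypothetical replacement for a removed `--sharded_ddp "zero_dp_3 auto_wrap"` run.
# Assumptions: CUDA GPUs are available and the script is launched with
# `torchrun --nproc_per_node=<N> this_script.py` (the `fsdp` option only applies in distributed mode).
import torch
from torch.utils.data import Dataset

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


class ToyDataset(Dataset):
    """Tiny stand-in dataset so the sketch is self-contained."""

    def __init__(self, tokenizer):
        self.enc = tokenizer(["a sample", "another sample"], padding=True, return_tensors="pt")
        self.labels = torch.tensor([0, 1])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return {**{k: v[i] for k, v in self.enc.items()}, "labels": self.labels[i]}


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

args = TrainingArguments(
    output_dir="out",
    fp16=True,                    # mixed precision now runs through accelerate; no ShardedGradScaler involved
    fsdp="full_shard auto_wrap",  # rough analogue of the removed `--sharded_ddp "zero_dp_3 auto_wrap"`
)

trainer = Trainer(model=model, args=args, train_dataset=ToyDataset(tokenizer))
trainer.train()
```

The same sharding can alternatively be configured through an `accelerate launch` FSDP config, in which case `fsdp` can stay unset and the `Trainer` picks the FSDP plugin up from the environment.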