Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make log_line_prefix_template Optional in Elastic Launcher for Backward Compatibility #2888

Merged
merged 3 commits into from
Jul 3, 2024

Conversation

yhna940
Copy link
Contributor

@yhna940 yhna940 commented Jun 25, 2024

What does this PR do?

This PR addresses an issue with the notebook_launcher in the accelerate library, where the inclusion of the log_line_prefix_template argument in the LaunchConfig instantiation leads to a TypeError in certain PyTorch versions. By making this argument optional, we ensure compatibility across all supported PyTorch versions.

Description

The notebook_launcher function in the accelerate library currently includes a log_line_prefix_template argument when creating a LaunchConfig instance. However, this argument is not supported in some versions of PyTorch, resulting in a TypeError. To maintain compatibility with all supported versions of PyTorch, this PR makes the inclusion of the log_line_prefix_template argument conditional.

Changes Made:

  • Added a log_line_prefix_template parameter to the notebook_launcher function.
  • Updated the notebook_launcher function to conditionally include the log_line_prefix_template argument in the LaunchConfig instantiation based on the PyTorch version.

Motivation and Context

Ensuring compatibility with a wider range of PyTorch versions is crucial for users who may not be able to upgrade to the latest versions. This change will allow users with different PyTorch versions to use the notebook_launcher without encountering errors.

Log / Env

  1. Log
    def notebook_launcher(
        function,
        args=(),
        num_processes=None,
        mixed_precision="no",
        use_port="29500",
        master_addr="127.0.0.1",
        node_rank=0,
        num_nodes=1,
        rdzv_backend="static",
        rdzv_endpoint="",
        rdzv_conf=None,
        rdzv_id="none",
        max_restarts=0,
        monitor_interval=0.1,
    ):
        """
        Launches a training function, using several processes or multiple nodes if it's possible in the current environment
        (TPU with multiple cores for instance).

        <Tip warning={true}>

        To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
        any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.

        Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
        of those calls have been made.

        </Tip>

        Args:
            function (`Callable`):
                The training function to execute. If it accepts arguments, the first argument should be the index of the
                process run.
            args (`Tuple`):
                Tuple of arguments to pass to the function (it will receive `*args`).
            num_processes (`int`, *optional*):
                The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
                the number of GPUs available otherwise.
            mixed_precision (`str`, *optional*, defaults to `"no"`):
                If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
            use_port (`str`, *optional*, defaults to `"29500"`):
                The port to use to communicate between processes when launching a multi-GPU training.
            master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
                The address to use for communication between processes.
            node_rank (`int`, *optional*, defaults to 0):
                The rank of the current node.
            num_nodes (`int`, *optional*, defaults to 1):
                The number of nodes to use for training.
            rdzv_backend (`str`, *optional*, defaults to `"static"`):
                The rendezvous method to use, such as 'static' (the default) or 'c10d'
            rdzv_endpoint (`str`, *optional*, defaults to `""`):
                The endpoint of the rdzv sync. storage.
            rdzv_conf (`Dict`, *optional*, defaults to `None`):
                Additional rendezvous configuration.
            rdzv_id (`str`, *optional*, defaults to `"none"`):
                The unique run id of the job.
            max_restarts (`int`, *optional*, defaults to 0):
                The maximum amount of restarts that elastic agent will conduct on workers before failure.
            monitor_interval (`float`, *optional*, defaults to 0.1):
                The interval in seconds that is used by the elastic_agent as a period of monitoring workers.

        Example:

        ```python
        # Assume this is defined in a Jupyter Notebook on an instance with two GPUs
        from accelerate import notebook_launcher


        def train(*args):
            # Your training function here
            ...


        notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
        ```
        """
        # Are we in a google colab or a Kaggle Kernel?
        in_colab = False
        in_kaggle = False
        if any(key.startswith("KAGGLE") for key in os.environ.keys()):
            in_kaggle = True
        elif "IPython" in sys.modules:
            in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())

        try:
            mixed_precision = PrecisionType(mixed_precision.lower())
        except ValueError:
            raise ValueError(
                f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
            )

        if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
            # TPU launch
            import torch_xla.distributed.xla_multiprocessing as xmp

            if len(AcceleratorState._shared_state) > 0:
                raise ValueError(
                    "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                    "your training function. Restart your notebook and make sure no cells initializes an "
                    "`Accelerator`."
                )
            if num_processes is None:
                num_processes = 8

            launcher = PrepareForLaunch(function, distributed_type="TPU")
            print(f"Launching a training on {num_processes} TPU cores.")
            xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
        elif in_colab and get_gpu_info()[1] < 2:
            # No need for a distributed launch otherwise as it's either CPU or one GPU.
            if torch.cuda.is_available():
                print("Launching training on one GPU.")
            else:
                print("Launching training on one CPU.")
            function(*args)
        else:
            if num_processes is None:
                raise ValueError(
                    "You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
                )
            if node_rank >= num_nodes:
                raise ValueError("The node_rank must be less than the number of nodes.")
            if num_processes > 1:
                # Multi-GPU launch
                from torch.distributed.launcher.api import LaunchConfig, elastic_launch
                from torch.multiprocessing import start_processes
                from torch.multiprocessing.spawn import ProcessRaisedException

                if len(AcceleratorState._shared_state) > 0:
                    raise ValueError(
                        "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
                        "inside your training function. Restart your notebook and make sure no cells initializes an "
                        "`Accelerator`."
                    )
                # Check for specific libraries known to initialize CUDA that users constantly use
                problematic_imports = are_libraries_initialized("bitsandbytes")
                if len(problematic_imports) > 0:
                    err = (
                        "Could not start distributed process. Libraries known to initialize CUDA upon import have been "
                        "imported already. Please keep these imports inside your training function to try and help with this:"
                    )
                    for lib_name in problematic_imports:
                        err += f"\n\t* `{lib_name}`"
                    raise RuntimeError(err)

                patched_env = dict(
                    nproc=num_processes,
                    node_rank=node_rank,
                    world_size=num_nodes * num_processes,
                    master_addr=master_addr,
                    master_port=use_port,
                    mixed_precision=mixed_precision,
                )

                # Check for CUDA P2P and IB issues
                if not check_cuda_p2p_ib_support():
                    patched_env["nccl_p2p_disable"] = "1"
                    patched_env["nccl_ib_disable"] = "1"

                # torch.distributed will expect a few environment variable to be here. We set the ones common to each
                # process here (the other ones will be set be the launcher).
                with patch_environment(**patched_env):
                    # First dummy launch
                    if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                        launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
                        try:
                            start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
                        except ProcessRaisedException as e:
                            err = "An issue was found when verifying a stable environment for the notebook launcher."
                            if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
                                raise RuntimeError(
                                    f"{err}"
                                    "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                                    "Please review your imports and test them when running the `notebook_launcher()` to identify "
                                    "which one is problematic and causing CUDA to be initialized."
                                ) from e
                            else:
                                raise RuntimeError(f"{err} The following error was raised: {e}") from e
                    # Now the actual launch
                    launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
                    print(f"Launching training on {num_processes} GPUs.")
                    try:
                        if rdzv_conf is None:
                            rdzv_conf = {}
                        if rdzv_backend == "static":
                            rdzv_conf["rank"] = node_rank
                            if not rdzv_endpoint:
                                rdzv_endpoint = f"{master_addr}:{use_port}"
>                       launch_config = LaunchConfig(
                            min_nodes=num_nodes,
                            max_nodes=num_nodes,
                            nproc_per_node=num_processes,
                            run_id=rdzv_id,
                            rdzv_endpoint=rdzv_endpoint,
                            rdzv_backend=rdzv_backend,
                            rdzv_configs=rdzv_conf,
                            max_restarts=max_restarts,
                            monitor_interval=monitor_interval,
                            start_method="fork",
                            log_line_prefix_template=os.environ.get("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE"),
                        )
E                       TypeError: LaunchConfig.__init__() got an unexpected keyword argument 'log_line_prefix_template'
  1. Accelerate Version
Name: accelerate
Version: 0.31.0
Summary: Accelerate
...
  1. Torch Version
Name: torch
Version: 2.0.1+cu118
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team

Torch Version Comparison

Torch Version 1.10

@dataclass
class LaunchConfig:
    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_dir: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE
    metrics_cfg: Dict[str, str] = field(default_factory=dict)

Torch Version 2.24

@dataclass
class LaunchConfig:
    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    logs_specs: Optional[LogsSpecs] = None
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 0.1
    start_method: str = "spawn"
    log_line_prefix_template: Optional[str] = None
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@yhna940 yhna940 changed the title [Hotfix] Remove log_line_prefix_template argument from LaunchConfig to ensure compatibility with supported PyTorch versions Remove log_line_prefix_template argument from LaunchConfig to ensure compatibility with supported PyTorch versions Jun 25, 2024
@BenjaminBossan
Copy link
Member

Thanks for this PR. I wonder if we could not check the PyTorch version and pass the argument only if it's supported? Alternatively, we could check the signature for this argument, but I'm not a big fan of that.

@yhna940
Copy link
Contributor Author

yhna940 commented Jun 25, 2024

Thanks for your review :) It looks like the minimum torch version that supports the target argument is 2.2.0, so I added logic to check for a specific version.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change. Just one small suggestion from my side.

@@ -235,7 +237,11 @@ def train(*args):
monitor_interval=monitor_interval,
start_method="fork",
)
elastic_launch(config=launch_config, entrypoint=function)(*args)
if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
launch_config_kwargs["log_line_prefix_template"] = os.environ.get(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of trying to get this from the env vars, I think it's better to import the value directly. I don't think there is ever a reason for a user to override this via env vars, is there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review. I've deleted the logic that takes that value as an environment variable :)

@yhna940 yhna940 requested a review from BenjaminBossan June 25, 2024 12:23
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM. Let's wait for another review before merging.

@yhna940 yhna940 changed the title Remove log_line_prefix_template argument from LaunchConfig to ensure compatibility with supported PyTorch versions Make log_line_prefix_template Optional in Elastic Launcher for Backward Compatibility Jul 1, 2024
@yhna940
Copy link
Contributor Author

yhna940 commented Jul 1, 2024

Hi there,

I wanted to kindly ask about the status of this PR. I don't mean to rush, but this PR is crucial for resolving a bug in my project. Your review and approval would be greatly appreciated :)

Thank you for your time and assistance!

Copy link
Collaborator

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@muellerzr muellerzr merged commit 404510a into huggingface:main Jul 3, 2024
23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants