
[<Ray component: Train>] Ray Train fails for AMD multi-gpu: Invalid Device Ordinal. #49260

Closed
amorinConnor opened this issue Dec 13, 2024 · 7 comments
Labels: bug, train, triage

@amorinConnor

What happened + What you expected to happen

I am trying to run some basic torch training code on an AMD machine with 4 GPUs. For multi-GPU training (num_workers > 1 in the script below), Ray fails with the following error:

(RayTrainWorker pid=1262281) Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=1262101) Started distributed worker processes:
(TorchTrainer pid=1262101) - (node_id=fb34f40133edf0f88d7ef5940e8c6ea1b0ec24528bc26b55fbbf4550, ip=192.168.193.244, pid=1262281) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=1262101) - (node_id=fb34f40133edf0f88d7ef5940e8c6ea1b0ec24528bc26b55fbbf4550, ip=192.168.193.244, pid=1262283) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=1262101) - (node_id=fb34f40133edf0f88d7ef5940e8c6ea1b0ec24528bc26b55fbbf4550, ip=192.168.193.244, pid=1262284) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=1262101) - (node_id=fb34f40133edf0f88d7ef5940e8c6ea1b0ec24528bc26b55fbbf4550, ip=192.168.193.244, pid=1262282) world_rank=3, local_rank=3, node_rank=0
(RayTrainWorker pid=1262281) Moving model to device: cuda:0
2024-12-13 11:12:48,791	ERROR tune_controller.py:1331 -- Trial task failed for trial TorchTrainer_35947_00000
Traceback (most recent call last):
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/_private/worker.py", line 2756, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/_private/worker.py", line 906, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::_Inner.train() (pid=1262101, ip=192.168.193.244, actor_id=4637f52479e6eee701d5044302000000, repr=TorchTrainer)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/train/_internal/utils.py", line 57, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(RuntimeError): ray::_RayTrainWorker__execute.get_next() (pid=1262282, ip=192.168.193.244, actor_id=9c0d3404f1d29f25a32f1b5002000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x152edf3790d0>)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/train/_internal/utils.py", line 196, in train_fn
    with train_func_context():
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/ray/train/torch/config.py", line 27, in __enter__
    torch.cuda.set_device(device)
  File "/usr/WS2/amorin1/venv/python3_9_rocm_6_1/lib/python3.9/site-packages/torch/cuda/__init__.py", line 478, in set_device
    torch._C._cuda_setDevice(device)
RuntimeError: HIP error: invalid device ordinal
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.

Training errored after 0 iterations at 2024-12-13 11:12:48. Total running time: 14s
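
(Not part of the original report: the HIP error text above suggests AMD_SERIALIZE_KERNEL=3 for debugging. A minimal sketch, assuming the driver calls ray.init itself, of forwarding that variable to every Ray worker via the job-level runtime_env:)

import ray

# Sketch only: forward the HIP debug variable suggested by the error message
# to all Ray worker processes so kernel errors surface closer to the offending call.
ray.init(
    runtime_env={
        "env_vars": {"AMD_SERIALIZE_KERNEL": "3"}
    }
)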

After getting help on the Ray forums (https://discuss.ray.io/t/torchtrainer-fails-rocm-multi-gpu-invalid-device-ordinal/21041), I verified the following test code:

import os

import torch

import ray
from ray.train.torch import get_device

# Each "Pass" block below documents the expected values for one scenario;
# the `==` lines are the expected results rather than enforced assertions.

# Scenario 1: only GPUs 2 and 3 visible; Ray assigns GPU 2 to this worker.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_device() == torch.device("cuda:0")
print("Pass 1")

# Scenario 2: all four GPUs visible; Ray assigns GPU 2 to this worker.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_device() == torch.device("cuda:2")
print("Pass 2")

# Scenario 3: all four GPUs visible; Ray assigns GPUs 2 and 3 to this worker.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.get_gpu_ids() == [2, 3]
torch.cuda.is_available() == True
get_device() == torch.device("cuda:2")
print("Pass 3")

# Moving a small model to the reported device should succeed.
model = torch.nn.Linear(in_features=1, out_features=1)
model.to(ray.train.torch.get_device())
print("Pass 4")

Versions / Dependencies

ray 2.40.0
torch 2.5.0+rocm6.1
Python 3.9.2
Red Hat 10.3.1

Note: this also fails on ROCm 6.2.

Reproduction script

import os
import tempfile

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# If using GPUs, set this to True.
use_gpu = True
# Number of processes to run training on.
num_workers = 4
# del os.environ['OMP_PLACES']
# del os.environ['OMP_PROC_BIND']
# Define your network structure.
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Training loop.
def train_loop_per_worker(config):

    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):

        # Train epoch.
        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Create checkpoint.
        base_model = (model.module
            if isinstance(model, DistributedDataParallel) else model)
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)


# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets
)

# Train the model.
result = trainer.fit()

# Inspect the results.
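# (Illustrative addition, not from the original script: the returned
# ray.train.Result exposes the final reported metrics and checkpoint.)
print(result.metrics)      # e.g. {"loss": ...}
print(result.checkpoint)   # last reported Checkpoint, if any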

Issue Severity

High: It blocks me from completing my task.

@amorinConnor added the bug and triage labels Dec 13, 2024
@AVSuni commented Dec 13, 2024

I'm getting the same error with Ray 2.40.0, torch 2.5.1, and ROCm 6.2. Single-GPU runs work fine, but every multi-GPU TorchTrainer run on AMD gives this error. The same code seems to run fine on NVIDIA GPUs.

@hongpeng-guo self-assigned this Dec 14, 2024
@AVSuni commented Dec 15, 2024

This error stems from torch trying to set a device based on CUDA_VISIBLE_DEVICES, while ROCR_VISIBLE_DEVICES only exposes a single device (i.e. device masking has already been done). A temporary fix is to set os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES') in ray/train/torch/config.py, as below:

class TorchConfigContextManager:
    def __enter__(self):
        # Set default cuda device
        os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES')
        if torch.cuda.is_available():
            device = ray.train.torch.get_device()
            if device.type == "cuda":
                print(f"ROCR_VISIBLE_DEVICES: {os.environ['ROCR_VISIBLE_DEVICES']}")
                print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
                print(f"setting device: {device}")
                torch.cuda.set_device(device)

If my kids allow me, I will open a PR with a better fix tonight unless @hongpeng-guo beats me to it
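
(An illustration of the mismatch described above, with hypothetical values, not from the original comment: once ROCR_VISIBLE_DEVICES has masked the worker down to one GPU, torch only sees ordinal 0, so requesting cuda:<local_rank> with local_rank > 0 is what raises "invalid device ordinal".)

import torch

local_rank = 2                       # hypothetical local rank of a failing worker
visible = torch.cuda.device_count()  # 1 once the ROCm mask is already applied
if local_rank >= visible:
    # This is the situation torch.cuda.set_device(f"cuda:{local_rank}") rejects.
    print(f"cuda:{local_rank} is out of range; only {visible} device(s) visible")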

@jcotant1 added the train label Dec 16, 2024
@amorinConnor (Author) commented Dec 16, 2024

> This error stems from torch trying to set a device based on CUDA_VISIBLE_DEVICES, while ROCR_VISIBLE_DEVICES only exposes a single device (i.e. device masking has already been done). A temporary fix is to set os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES') in ray/train/torch/config.py [...]

I noticed this as well, but unfortunately it's not enough to get the above training script running for me. I no longer hit the invalid device ordinal, but instead get a much less informative error:

(RayTrainWorker pid=678735) Moving model to device: cuda:0 [repeated 3x across cluster]
(RayTrainWorker pid=678733) *** SIGSEGV received at time=1734370476 on cpu 120 ***
(RayTrainWorker pid=678733) PC: @     0x1555540a1912  (unknown)  __memmove_avx512_unaligned_erms
(RayTrainWorker pid=678733)     @     0x155554b35d10  (unknown)  (unknown)
(RayTrainWorker pid=678733) [2024-12-16 09:34:36,710 E 678733 679677] logging.cc:447: *** SIGSEGV received at time=1734370476 on cpu 120 ***
(RayTrainWorker pid=678733) [2024-12-16 09:34:36,710 E 678733 679677] logging.cc:447: PC: @     0x1555540a1912  (unknown)  __memmove_avx512_unaligned_erms
(RayTrainWorker pid=678733) [2024-12-16 09:34:36,710 E 678733 679677] logging.cc:447:     @     0x155554b35d10  (unknown)  (unknown)
(RayTrainWorker pid=678733) Fatal Python error: Segmentation fault
(RayTrainWorker pid=678733)
(RayTrainWorker pid=678736)
(RayTrainWorker pid=678734)
(RayTrainWorker pid=678735)
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffc46bc3b30b8e8071e9290fdf01000000 Worker ID: df8b27019ed04b36b05916bd036496ca67a825b36c5bcc729cd30e1f Node ID: 9733fdbbc515603f523155fb8ffbb1bee51da955ce895fe71f23c45e Worker IP address: 192.168.192.151 Worker port: 10197 Worker PID: 678736 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
(TorchTrainer pid=678554) Worker 3 has failed.

I'm not sure yet whether the above change is actually a solution or whether this is an entirely separate problem, but it at least still appears to be CUDA-device related.

@AVSuni commented Dec 16, 2024

@amorinConnor That looks like a separate problem, unrelated to GPU indexing/visibility/communication.

@hongpeng-guo (Contributor)

> If my kids allow me, I will open a PR with a better fix tonight unless @hongpeng-guo beats me to it

@AVSuni Thank you so much for helping with this problem. Feel free to start a PR; I can help with it later. 👍

@amorinConnor (Author)

@AVSuni Have you confirmed this works on an AMD multi-GPU setup? For me, the code now fails when wrapping the model in DistributedDataParallel.

@hongpeng-guo (Contributor)

Closing this issue for now. Solved by #49346.
