🐛 [Bug] Multi-GPU model moved to single GPU #2269

Closed
@austinapatel

Description

Bug Description

I'm experimenting with TorchTRT on a model partitioned across two GPUs using pipeline parallelism: the first half of my network is on GPU0 and the second half is on GPU1. When I execute the model in PyTorch eager mode, each layer's kernels run on its assigned GPU, as expected. When I compile the network with TorchTRT, however, the whole network is moved to one of the GPUs and executed there instead of being split across both. This prevents using TorchTRT with very large models that don't fit in the memory of a single GPU.

To Reproduce

Steps to reproduce the behavior:

  1. Run the profiling command: nsys profile --trace cuda,nvtx --sample cpu --force-overwrite true --output profiling_results/tmp --gpu-metrics-device=all --gpu-metrics-frequency=20000 python torchtrt_multigpu_issue.py

torchtrt_multigpu_issue.py:

import torch
from torch import nn
from torch.nn import functional as F
import torch_tensorrt

"""NETWORK"""
class SimpleNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 100).to('cuda:0')
        self.fc2 = nn.Linear(100, 5).to('cuda:0')
        self.mha = nn.MultiheadAttention(embed_dim=5, num_heads=1).to('cuda:1')
    
    def forward(self, x):
        x = x.to('cuda:0')
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = x.to('cuda:1')
        x, _ = self.mha(x, x, x)
        x = F.relu(x)
        return x

"""PROFILING UTILITIES"""
def timed(fn, msg):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.nvtx.mark(f'timed start: {msg}')
    start.record()
    result = fn()
    end.record()
    torch.cuda.nvtx.mark(f'timed end: {msg}')
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


def profile(method, nvtx_msg=None):
    # profiling
    print('Profiling with torch NVTX markers enabled')
    torch.cuda.cudart().cudaProfilerStart()
    with torch.autograd.profiler.emit_nvtx(record_shapes=True):
        if nvtx_msg:
            torch.cuda.nvtx.mark(f'profile start: {nvtx_msg}')
        result = method()
        if nvtx_msg:
            torch.cuda.nvtx.mark(f'profile end: {nvtx_msg}')
    torch.cuda.cudart().cudaProfilerStop()

    return result


def run_two_models(model1, model2, inp):
    torch.cuda.synchronize()

    output1, time1 = timed(lambda: model1(inp), f'profiling model1')
    output2, time2 = timed(lambda: model2(inp), f'profiling model2')

    return time1, time2

"""MAIN"""
if __name__ == '__main__':
    inp = torch.ones((10,10), device='cuda:0')

    # Eager
    eager_model = SimpleNetwork()
    eager_model.eval()
    print(f'Calling eager model')
    eager_model(inp)

    # TorchTRT
    torch._dynamo.reset()
    # inputs is documented as a sequence of tensors / torch_tensorrt.Input objects
    trt_model = torch_tensorrt.compile(eager_model, ir="torch_compile", inputs=[inp], use_python_runtime=False)
    print(f'Calling trt_model')
    trt_model(inp)

    # Profiling
    time_eager, time_trt = profile(lambda: run_two_models(eager_model, trt_model, inp))
    print(f'Eager time: {time_eager}ms\nTime TRT: {time_trt}ms')
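A possible caveat on the `timed` helper in the script above, offered as a hedged observation rather than a confirmed cause: `torch.cuda.Event.record()` records on the current stream of the current device, and `torch.cuda.synchronize()` with no argument waits only for the current device (cuda:0 here). With a two-GPU pipeline, the end event may therefore be recorded before the `mha` kernels on cuda:1 have finished, which could understate the eager time. The sketch below is a device-agnostic alternative that calls one synchronize callback per device before starting and before stopping a wall-clock timer. `timed_all_devices` and its parameters are hypothetical names, not part of PyTorch or Torch-TensorRT.

```python
import time
from typing import Callable, Iterable, Tuple, TypeVar

T = TypeVar("T")


def timed_all_devices(
    fn: Callable[[], T],
    synchronize_fns: Iterable[Callable[[], None]],
) -> Tuple[T, float]:
    """Run fn() and return (result, elapsed milliseconds).

    Hypothetical helper: every callback in synchronize_fns is called before
    starting and before stopping the clock, so asynchronous work queued on
    any device is included in the measurement.
    """
    syncs = list(synchronize_fns)
    for sync in syncs:  # drain pending work on every device first
        sync()
    start = time.perf_counter()
    result = fn()
    for sync in syncs:  # wait for every device before stopping the clock
        sync()
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    return result, elapsed_ms
```

With PyTorch, the callbacks could be `[functools.partial(torch.cuda.synchronize, d) for d in range(torch.cuda.device_count())]`. Wall-clock timing also includes Python launch overhead, which CUDA events exclude, so its numbers are not directly comparable with those from `timed`.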

Output:

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(precision=torch.float32, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False)

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.036864
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:12.406040
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 10752 bytes of Memory
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.002983
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.182259
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 1024 bytes of Memory
WARNING: [Torch-TensorRT] - Input 5 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 6 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 7 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 8 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 9 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 10 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 11 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 12 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 5 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 6 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 7 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 8 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 9 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 10 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 11 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
WARNING: [Torch-TensorRT] - Input 12 of engine _run_on_acc_1_engine was found to be on cuda:1 but should be on cuda:0. This tensor is being moved by the runtime but for performance considerations, ensure your inputs are all on GPU and open an issue here (https://github.com/pytorch/TensorRT/issues) if this warning persists.
Calling eager model
Calling trt_model
Profiling with torch NVTX markers enabled
Eager time: 1.8095359802246094ms
Time TRT: 3.959264039993286ms
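The repeated warnings and the two timings can be summarized programmatically. A minimal stdlib sketch, assuming the log lines have exactly the shape shown above (`summarize_run` and the regexes are illustrative, not a Torch-TensorRT utility):

```python
import re
from collections import defaultdict
from typing import Dict, List, Tuple

# Matches "Input N of engine NAME was found to be on DEV but should be on DEV."
WARNING_RE = re.compile(
    r"Input (\d+) of engine (\S+) was found to be on (\S+) but should be on (\S+?)\."
)
# Matches the two timing lines printed at the end of the script
TIMING_RE = re.compile(r"^(Eager time|Time TRT): ([\d.]+)ms$", re.MULTILINE)


def summarize_run(log: str) -> Tuple[Dict[Tuple[str, str, str], List[int]], Dict[str, float]]:
    """Group device-mismatch warnings and extract the reported timings.

    Returns ({(engine, actual_device, expected_device): sorted input indices},
             {"Eager time": ms, "Time TRT": ms}).
    """
    mismatches = defaultdict(set)  # duplicate warnings collapse into one index
    for idx, engine, actual, expected in WARNING_RE.findall(log):
        mismatches[(engine, actual, expected)].add(int(idx))
    timings = {label: float(ms) for label, ms in TIMING_RE.findall(log)}
    return {key: sorted(v) for key, v in mismatches.items()}, timings
```

Applied to the output above, this groups the sixteen warnings into inputs 5 through 12 of `_run_on_acc_1_engine` (found on cuda:1, expected on cuda:0), and the TRT/eager time ratio comes out at about 2.19.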

Note: if you are using a TorchTRT debug build, you might get an error about profiling already being enabled. In this case, you can either:

  1. Replace time_eager, time_trt = profile(lambda: run_two_models(eager_model, trt_model, inp)) with time_eager, time_trt = run_two_models(eager_model, trt_model, inp)
  2. Change
    cmd.append("--compilation_mode=dbg")
    to cmd.append("--compilation_mode=opt") and rebuild TorchTRT.

Expected behavior

I was hoping that separate TRT engines would be created and executed on each GPU, rather than having everything moved to a single GPU and executed there. I'm curious whether there is a workaround for this, or whether a more fundamental limitation prevents it. With the PyTorch Inductor backend, I can create compiled components that use both GPUs, which is the behavior I expected.
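The expected behavior amounts to cutting the graph wherever the device changes and building one engine per contiguous same-device segment. A minimal, framework-free sketch of that grouping step, assuming the layers are listed in execution order with their target devices (`split_by_device` is a hypothetical helper, not a Torch-TensorRT API, and this is not how its partitioner is implemented):

```python
from itertools import groupby
from typing import List, Sequence, Tuple


def split_by_device(layers: Sequence[Tuple[str, str]]) -> List[Tuple[str, List[str]]]:
    """Group (layer_name, device) pairs, in execution order, into contiguous
    same-device segments; each segment is a candidate for its own engine."""
    return [
        (device, [name for name, _ in group])
        for device, group in groupby(layers, key=lambda pair: pair[1])
    ]
```

For the SimpleNetwork layout (the `relu*` names are illustrative labels for the `F.relu` calls), `[("fc1", "cuda:0"), ("relu1", "cuda:0"), ("fc2", "cuda:0"), ("relu2", "cuda:0"), ("mha", "cuda:1"), ("relu3", "cuda:1")]` yields two segments: `("cuda:0", ["fc1", "relu1", "fc2", "relu2"])` and `("cuda:1", ["mha", "relu3"])`. Because `groupby` only merges adjacent items, a layout that returns to an earlier device produces a new segment, which preserves pipeline order.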

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): commit 65c6494ce3107d33b27fdb1630ac7982f8649382 built from source
  • PyTorch Version (e.g. 1.0): 2.1.0a0+4136153 (commit 7682252cac9ed31055d2ae950cf6942dd311da73)
  • CPU Architecture:
  • OS (e.g., Linux): Ubuntu 22.04.2 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): source
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.6
  • CUDA version: 12.1
  • GPU models and configuration: NVIDIA A100 80GB PCIe
  • Any other relevant information:

Additional context

Nsight Systems (nsys) trace using TorchTRT. Note that all layers are copied to one GPU and all engines execute on that GPU.
[Figure: multigpu_trt_annotated]

Nsight Systems (nsys) trace using Torch Inductor. Note that the model parallelism is preserved and both GPUs are utilized.
[Figure: multigpu_inductor_annotated]

Metadata

Labels

bug (Something isn't working)
