
Multi-GPU-Arch pre-compilation of operators not supported #7863

@Flamefire

Description

Part of the host code relies on __CUDA_ARCH__, which must not be done. When compiling for multiple GPU architectures, that macro can evaluate differently in each compilation pass, leading to ODR violations, or the CUDA equivalent of that.

I noticed this in a test: pytest DeepSpeed/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py -k 'test_DS4Sci_EvoformerAttention[tensor_shape0-dtype0]'

Depending on the order of CUDA archs in TORCH_CUDA_ARCH_LIST, the test fails with cudaErrorInvalidDeviceFunction (with CUDA 12.8 or later it is cudaErrorInvalidResourceHandle instead).

Reproducible by compiling DeepSpeed with pre-compiled operators:
TORCH_CUDA_ARCH_LIST='8.0;7.0' DS_BUILD_OPS=0 DS_BUILD_EVOFORMER_ATTN=1 pip install .

Then running the above test on an A100 (CUDA CC 8.0) fails.
Swapping the archs to `TORCH_CUDA_ARCH_LIST='7.0;8.0'` makes it succeed. However, I expect it will then fail on SM 7.0 devices.

I traced this to attention_impl_template and attention_back_impl_template, which both have the same issue in this declaration:

typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_template

CheckArch::value depends on the value of __CUDA_ARCH__:

static constexpr bool compiler_cc = arch::kMinComputeCapability * 10 <= __CUDA_ARCH__;

So depending on which compilation pass is running, it will evaluate once to true and once to false. When the multiple object files are linked together, the calling function (attention_impl) thus sees differing implementations -> ODR violation.
The order of -gencode flags passed to nvcc also seems to affect the ordering inside the fatbin file, which could be another reason why the correct function may or may not be found.

I made a small reproducer to show the issue:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <type_traits>  // std::enable_if

template<typename T>
__global__ void test_kernel(T *out) {}

template<typename arch>
struct Check {
  static constexpr bool value =
#if defined(__CUDA_ARCH__)
    __CUDA_ARCH__ >= 800;
#else
    true;
#endif
};

template <typename arch>
typename std::enable_if<!Check<arch>::value>::type foo() {}
template <typename arch>
typename std::enable_if<Check<arch>::value>::type foo()
{
  cudaFuncAttributes attr;
  auto func = test_kernel<float>;
  auto e = cudaFuncGetAttributes(&attr, func);
  if (e != cudaSuccess) {
    fprintf(stderr, "%s\n", cudaGetErrorString(e));
    std::exit(1);
  }
}

int main() {
    foo<bool>();
    printf("SUCCESS\n");
    return 0;
}
  • nvcc repro.cu -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_70,code=sm_70 -o repro_80_first && ./repro_80_first → FAILS
  • nvcc repro.cu -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -o repro_70_first && ./repro_70_first → WORKS

My understanding is that __CUDA_ARCH__ is meant to be used in __device__ code (or the device compilation pass of __host__ __device__ code) and must not affect host code.
I only isolated this single issue, but there might be other instances where __CUDA_ARCH__ is used incorrectly, especially through indirect uses like #if __CUDA_ARCH__ --> #define FOO and later in host code #ifdef FOO

For JIT-compiled operations this isn't usually an issue because there is typically only one type of GPU in the system, so only a single -gencode flag will be passed:

if self.jit_mode:
    # Compile for underlying architectures since we know those at runtime
    for i in range(torch.cuda.device_count()):
        CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
        cc = f"{CC_MAJOR}.{CC_MINOR}"
        if cc not in ccs:
            ccs.append(cc)
    ccs = sorted(ccs)
    ccs[-1] += '+PTX'

Although even the JIT path may cause problems:

# Torch will try and apply whatever CCs are in the arch list at compile time,
# we have already set the intended targets ourselves we know that will be
# needed at runtime. This prevents CC collisions such as multiple __half
# implementations. Stash arch list to reset after build.
torch_arch_list = None
if "TORCH_CUDA_ARCH_LIST" in os.environ:
torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
os.environ["TORCH_CUDA_ARCH_LIST"] = ""

Setting TORCH_CUDA_ARCH_LIST to an empty string may cause PyTorch to add the archs of all visible cards:

/tmp/lib/python3.11/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.

And indeed: the JIT nvcc command line shows multiple -gencode=arch=compute_80,code=compute_80 flags. I counted 4 in total, 2 of them after --threads=8, the latter likely added by DeepSpeed.
