Description
I'm trying to set up IPEX on a system with 8 PVC tiles and am having difficulty getting it working. Right now I'm just running some sanity tests, and a basic P2P test is failing.
Steps Taken So Far
- I installed IPEX using the instructions linked in the repo. The install appears successful.
- After sourcing oneAPI with
source /opt/intel/oneapi/setvars.sh
and setting myLD_LIBRARY_PATH
to point to the pip'slib
folder as well, the sanity test from the install instructions completes successfully with a warning (see below).
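
As a check that the `LD_LIBRARY_PATH` setup took effect, I can probe it from Python. This is my own sketch, not from the install docs; it assumes the oneCCL bindings wheel ships its libraries in a `lib/` folder inside the package, which may differ between versions:

```python
# My own check (not from the install docs) that LD_LIBRARY_PATH covers the
# lib/ directory inside the oneccl_bindings_for_pytorch package.
# The wheel layout (lib/ next to __init__.py) is an assumption.
import os
import oneccl_bindings_for_pytorch as torch_ccl

pkg_lib = os.path.join(os.path.dirname(torch_ccl.__file__), 'lib')
ld_paths = os.environ.get('LD_LIBRARY_PATH', '').split(':')
print('package lib dir:', pkg_lib)
print('on LD_LIBRARY_PATH:', pkg_lib in ld_paths)
```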
Simple P2P Check
I then tried to run a simple P2P check to measure bandwidth between devices:
```python
#!/usr/bin/env python
import os
import sys
import time

import torch
import torch.distributed as dist
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch as torch_ccl


def get_device():
    return 'xpu:%s' % (dist.get_rank() % torch.xpu.device_count(),)


def get_rank_from_env():
    if 'PMI_RANK' in os.environ:
        return os.environ['PMI_RANK']
    elif 'PMIX_RANK' in os.environ:
        return os.environ['PMIX_RANK']
    elif 'RANK' in os.environ:
        return os.environ['RANK']
    else:
        raise Exception('Error: neither \'PMI_RANK\' nor \'RANK\' environment variable found. Are you invoking this script using mpirun or torchrun?')


def get_nprocs_from_env():
    if 'PMI_SIZE' in os.environ:
        return os.environ['PMI_SIZE']
    elif 'WORLD_SIZE' in os.environ:
        return os.environ['WORLD_SIZE']
    else:
        raise Exception('Error: neither \'PMI_SIZE\' nor \'WORLD_SIZE\' environment variable found. Are you invoking this script using mpirun or torchrun?')


os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = get_rank_from_env()
os.environ["WORLD_SIZE"] = get_nprocs_from_env()
print('My guessed rank = %s' % os.environ["RANK"])

dist.init_process_group(backend="ccl", init_method="env://")

# 1 GiB payload, rounded to a whole number of float32 elements.
nbytes = 1024*1024*1024
n = nbytes // 4
nbytes = n * 4
gbytes = nbytes * 1e-9

print('Process %s/%s using device %s' % (dist.get_rank(), dist.get_world_size(), get_device()))

send_tensor = torch.zeros(n, dtype=torch.float32, device=get_device())
recv_tensor = torch.zeros(n, dtype=torch.float32, device=get_device())

# Perform an all_reduce to initialize communicators and such.
dist.all_reduce(send_tensor)

if dist.get_rank() == 0:
    print('Benchmarking P2P...')

# Time a 1 GiB isend/irecv between every ordered pair of ranks.
for send_rank in range(dist.get_world_size()):
    for recv_rank in range(dist.get_world_size()):
        if send_rank != recv_rank:
            dist.barrier()
            if dist.get_rank() == send_rank:
                print('Send %s -> %s' % (send_rank, recv_rank))
            dist.barrier()
            begin = time.time()
            reqs = []
            if dist.get_rank() == send_rank:
                req = dist.isend(send_tensor, recv_rank)
                reqs.append(req)
            if dist.get_rank() == recv_rank:
                req = dist.irecv(recv_tensor, send_rank)
                reqs.append(req)
            for req in reqs:
                req.wait()
            end = time.time()
            duration = end - begin
            if dist.get_rank() == recv_rank:
                print('%s -> %s took %s s, achieved %s GB/s' % (send_rank, recv_rank, duration, gbytes / duration))
```
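
To help localize the hang, it may be worth swapping the non-blocking calls for blocking ones. Below is a sketch of my own (not from any IPEX example); the helper name `blocking_p2p` is hypothetical, and it assumes the same `mpirun`/`torchrun` setup as the script above:

```python
# Blocking variant of a single transfer, to rule out issues in the
# isend/irecv request path (my own sketch; helper name is hypothetical).
import torch.distributed as dist

def blocking_p2p(send_rank, recv_rank, send_tensor, recv_tensor):
    rank = dist.get_rank()
    if rank == send_rank:
        dist.send(send_tensor, dst=recv_rank)   # blocks until the send completes
    elif rank == recv_rank:
        dist.recv(recv_tensor, src=send_rank)   # blocks until the data arrives
    dist.barrier()                              # keep all ranks in lockstep
```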
The output is as follows (with the previously mentioned ATen warning removed):
```
(ipex-new) bbrock@hedp030:~/src/ai/pytorch-gemm/ccl> cat out.dat
My guessed rank = 4
My guessed rank = 0
My guessed rank = 1
My guessed rank = 2
My guessed rank = 3
My guessed rank = 5
My guessed rank = 6
My guessed rank = 7
Process 6/8 using device xpu:6
Process 4/8 using device xpu:4
Process 5/8 using device xpu:5
Process 3/8 using device xpu:3
Process 7/8 using device xpu:7
Process 2/8 using device xpu:2
Process 0/8 using device xpu:0
Process 1/8 using device xpu:1
Benchmarking P2P...
Send 0 -> 1
0 -> 1 took 0.30544233322143555 s, achieved 3.51536675573249 GB/s
Send 0 -> 2
```
Two problems stand out:

- It blocks indefinitely on the send from 0 -> 2.
- The bandwidth is far lower than expected: it reports 3.5 GB/s when it should be >150 GB/s over MDFI between tiles 0 and 1 (even Xe Link would be 20 GB/s).
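
For context, the script's bandwidth figure is just payload size over elapsed time; checking the arithmetic against the numbers above:

```python
# Plain arithmetic behind the reported figure (no assumptions beyond the
# values printed in the output above).
nbytes = 1024 * 1024 * 1024        # 1 GiB payload, as in the script
duration = 0.30544233322143555     # measured 0 -> 1 transfer time
print(nbytes * 1e-9 / duration)    # ~3.52 GB/s, matching the output
print(nbytes / 150e9)              # ~0.0072 s expected at 150 GB/s
```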
The GPUs on the system appear to be configured correctly:
```
(ipex-new) bbrock@hedp030:~/src/ai/pytorch-gemm/ccl> sycl-ls
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2024.17.3.0.08_160000]
[opencl:cpu:1] Intel(R) OpenCL, Intel(R) Xeon(R) Platinum 8480+ OpenCL 3.0 (Build 0) [2024.17.3.0.08_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:3] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:4] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:5] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:6] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:7] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:8] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[opencl:gpu:9] Intel(R) OpenCL Graphics, Intel(R) Data Center GPU Max 1550 OpenCL 3.0 NEO [25.05.32567]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:1] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:2] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:3] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:4] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:5] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:6] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
[ext_oneapi_level_zero:gpu:7] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.6 [1.3.32567]
(ipex-new) bbrock@hedp030:~/src/ai/pytorch-gemm/ccl> xpu-smi topology -m
         GPU 0/0  GPU 0/1  GPU 1/0  GPU 1/1  GPU 2/0  GPU 2/1  GPU 3/0  GPU 3/1  CPU Affinity
GPU 0/0  S        MDF      XL*      XL8      XL8      XL*      XL8      XL*      0-55,112-167
GPU 0/1  MDF      S        XL8      XL*      XL*      XL8      XL*      XL8      0-55,112-167
GPU 1/0  XL*      XL8      S        MDF      XL*      XL8      XL*      XL8      0-55,112-167
GPU 1/1  XL8      XL*      MDF      S        XL8      XL*      XL8      XL*      0-55,112-167
GPU 2/0  XL8      XL*      XL*      XL8      S        MDF      XL8      XL*      56-111,168-223
GPU 2/1  XL*      XL8      XL8      XL*      MDF      S        XL*      XL8      56-111,168-223
GPU 3/0  XL8      XL*      XL*      XL8      XL8      XL*      S        MDF      56-111,168-223
GPU 3/1  XL*      XL8      XL8      XL*      XL*      XL8      MDF      S        56-111,168-223
```
I get the same results whether I use the `mpirun` bundled with the pip packages or the system's Intel MPI. Please advise on what to do.
Sanity Check Warning
The warning produced by the sanity check after install is about ATen op registration. I saw someone was told in another issue that this warning can be ignored, so I'm ignoring it and assuming the install is successful.
```
(ipex-new) bbrock@hedp030:~/src/ai/pytorch-gemm/ccl> !pyth
python3 -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__); [print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())];"
[W416 15:51:01.385586326 OperatorEntry.cpp:154] Warning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
dispatch key: XPU
previous kernel: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:30477
new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:468 (function operator())
2.6.0+xpu
2.6.10+xpu
[0]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[1]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[2]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[3]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[4]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[5]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[6]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[7]: _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=65536MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=64, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
[W416 15:51:04.037020987 OperatorEntry.cpp:154] Warning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
dispatch key: XPU
previous kernel: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:30477
new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:468 (function operator())
```