torch.linalg.eigh is significantly slower than expected on Max Series GPU #439

@ogrisel

Describe the issue

Similarly to #428, I tried torch.linalg.eigh on a Max Series GPU using the Intel Devcloud and packages from the intel conda channel; the performance on XPU is not much better than on CPU:

>>> import intel_extension_for_pytorch
>>> import torch
>>> intel_extension_for_pytorch.__version__
'2.0.110+xpu'
>>> torch.__version__
'2.0.1a0+cxx11.abi'
>>> X = torch.randn(500, 500)
>>> X_xpu = X.to("xpu")
>>> %time C = X.T @ X
CPU times: user 938 ms, sys: 76.8 ms, total: 1.01 s
Wall time: 115 ms
>>> %time C_xpu = X_xpu.T @ X_xpu
CPU times: user 4.37 ms, sys: 4 µs, total: 4.37 ms
Wall time: 4.21 ms

So GEMM is roughly 27x faster on the XPU device than on the CPU host (115 ms vs 4.21 ms wall time).
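
Note that XPU kernels may execute asynchronously, so the %time numbers above could undercount the GEMM cost. Here is a minimal re-timing sketch with an explicit synchronization, assuming the torch.xpu.synchronize() API exposed by intel_extension_for_pytorch:

import time

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the "xpu" device)

X = torch.randn(500, 500)
X_xpu = X.to("xpu")

# Warm up so one-time kernel compilation is not included in the timing.
_ = X_xpu.T @ X_xpu
torch.xpu.synchronize()

tic = time.perf_counter()
C_xpu = X_xpu.T @ X_xpu
torch.xpu.synchronize()  # wait for the kernel to actually finish
print(f"GEMM on xpu: {time.perf_counter() - tic:.4f} s")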

However, torch.linalg.eigh is not faster when using the XPU, which is quite unexpected given the speed difference for GEMM.

>>> %time _ = torch.linalg.eigh(C)
CPU times: user 2min 30s, sys: 10.2 s, total: 2min 40s
Wall time: 6.89 s
>>> %time _ = torch.linalg.eigh(C_xpu)
CPU times: user 4min 1s, sys: 14.5 s, total: 4min 15s
Wall time: 5.52 s

More information about the runtime environment of this session:

>>> import dpctl
>>> from pprint import pprint
>>> pprint(dpctl.get_devices())
[<dpctl.SyclDevice [backend_type.opencl, device_type.cpu,  Intel(R) Xeon(R) Platinum 8480+] at 0x1472aac521f0>,
 <dpctl.SyclDevice [backend_type.opencl, device_type.accelerator,  Intel(R) FPGA Emulation Device] at 0x1472a80a9ef0>,
 <dpctl.SyclDevice [backend_type.level_zero, device_type.gpu,  Intel(R) Data Center GPU Max 1100] at 0x1472a80a9df0>]
>>> import joblib
>>> joblib.cpu_count(only_physical_cores=True)
112
>>> import threadpoolctl
>>> pprint(threadpoolctl.threadpool_info())
[{'filepath': '/home/u103854/mambaforge/envs/intel/lib/libmkl_rt.so.2',
  'internal_api': 'mkl',
  'num_threads': 112,
  'prefix': 'libmkl_rt',
  'threading_layer': 'intel',
  'user_api': 'blas',
  'version': '2023.2-Product'},
 {'filepath': '/home/u103854/mambaforge/envs/intel/lib/libiomp5.so',
  'internal_api': 'openmp',
  'num_threads': 112,
  'prefix': 'libiomp',
  'user_api': 'openmp',
  'version': None},
 {'filepath': '/home/u103854/mambaforge/envs/intel/lib/libgomp.so.1.0.0',
  'internal_api': 'openmp',
  'num_threads': 112,
  'prefix': 'libgomp',
  'user_api': 'openmp',
  'version': None}]

Furthermore, all of those timings are extremely slow in absolute terms for such a small (500x500) matrix.

Here is the output of a similar experiment on my local laptop (Apple M1):

>>> import torch
>>> X = torch.randn(500, 500)
>>> %time C = X.T @ X
CPU times: user 247 µs, sys: 718 µs, total: 965 µs
Wall time: 4.5 ms
>>> %time _ = torch.linalg.eigh(C)
CPU times: user 12.3 ms, sys: 6.88 ms, total: 19.2 ms
Wall time: 20.6 ms
