Performance regression with IPEX 2.3, TORCH 2.6 compared with IPEX 2.1 #768

Open
@gc-fu

Describe the issue

We encountered a performance regression that we think may be related to intel-extension-for-pytorch.
Specifically, the performance of the gemm_kernel is inconsistent across the following setups:

  • Torch 2.1 + IPEX 2.1 + oneAPI 2024.1
  • Torch 2.3 + IPEX 2.3 + oneAPI 2024.2
  • Torch 2.6 + oneAPI 2025.0
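
When comparing setups like these, it helps to record the exact package versions next to each result. The following is a minimal sketch, not part of the original report. It assumes the pip distribution names `torch` and `intel-extension-for-pytorch`; the helper name `installed_versions` is hypothetical. It does not detect the oneAPI version, which has to be noted separately.

```python
from importlib import metadata

# Assumed pip distribution names for the packages being compared.
PACKAGES = ("torch", "intel-extension-for-pytorch")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version string.

    Packages that are not installed are reported as "not installed"
    instead of raising, so the report also works on the Torch-only setup.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
```

Printing this alongside `torch.__config__.show()` makes it easier to tell which stack produced which numbers.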

Below are the results of our testing:

[Image: benchmark results across the setups above]

The test script we used is listed as follows:

import torch
import intel_extension_for_pytorch
import torch.nn as nn
import time
import tqdm


def bench_linear(M, K, N, warm_up, iter_num, device, dtype):
    input_tensor = torch.randn(M, K, device=device, dtype=dtype)
    linear = nn.Linear(
        in_features=K, out_features=N, bias=False, device=device, dtype=dtype
    )
    with torch.no_grad():
        linear.weight.copy_(torch.randn(N, K))

    total_time = 0
    for i in tqdm.tqdm(range(warm_up + iter_num)):
        torch.xpu.synchronize()
        st = time.time()
        output = linear(input_tensor)
        torch.xpu.synchronize()
        et = time.time()
        # Only accumulate timings once the warm-up iterations are done.
        if i >= warm_up:
            total_time += (et - st) * 1000

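    avg_latency = total_time / iter_num  # average latency in milliseconds
    # FLOP count divided by latency in ms; the 1000 / 1e12 factors convert to TFLOPS.
    tflops = (2 * M * K * N + 3 * M * N) / avg_latency / 1e12 * 1000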
    print(
        f"Shape: {M}x{K}:{K}x{N}, Data Type:{dtype}, TFLOPS: {tflops:.2f}, Avg Latency: {avg_latency:.2f}"
    )


if __name__ == "__main__":

    print(torch.__config__.show())
    matrix_sizes = [
        (3000, 3584, 4608),
        (3000, 3584, 3584),
        (3000, 3584, 18944),
        (256, 3584, 4096),
        (256, 3584, 512),
        (2048, 3584, 4608),
        (2048, 3584, 3584),
        (2048, 3584, 37888),
        (2048, 18944, 3584),
    ]
    device = torch.device("xpu")
    dtype = torch.float16  # * 87.91 tflops
    # dtype = torch.float32 # * 16.85 tflops

    for size in matrix_sizes:
        M, K, N = size
        bench_linear(M, K, N, 30, 1000, device, dtype)

For some specific shapes, performance on the newer setups is nearly half of what we measured with Torch 2.1 + IPEX 2.1.
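
To make the per-shape comparison explicit, the latencies printed by the script can be compared between two setups. This is a rough sketch, not the method used for the results above. The helper names `gemm_tflops` and `flag_regressions` and the 1.5x threshold are hypothetical, and the latencies in the usage section are placeholders rather than measurements. Unlike the script above, this sketch counts only the `2 * M * K * N` multiply-add FLOPs.

```python
def gemm_tflops(m, k, n, latency_ms):
    """Throughput in TFLOPS for an (m x k) @ (k x n) matmul.

    Counts 2 * m * k * n FLOPs (one multiply and one add per term).
    """
    flops = 2 * m * k * n
    return flops / (latency_ms * 1e-3) / 1e12


def flag_regressions(baseline_ms, candidate_ms, threshold=1.5):
    """Return {shape: slowdown} for shapes that got slower by >= threshold.

    Both arguments map (M, K, N) tuples to average latency in milliseconds.
    Shapes missing from either mapping are skipped.
    """
    flagged = {}
    for shape, base in baseline_ms.items():
        if shape not in candidate_ms:
            continue
        ratio = candidate_ms[shape] / base
        if ratio >= threshold:
            flagged[shape] = ratio
    return flagged


if __name__ == "__main__":
    # Placeholder latencies in ms for illustration only, NOT measured values.
    baseline = {(3000, 3584, 4608): 1.0, (256, 3584, 512): 0.1}
    candidate = {(3000, 3584, 4608): 2.0, (256, 3584, 512): 0.1}
    for shape, ratio in flag_regressions(baseline, candidate).items():
        m, k, n = shape
        print(
            f"{m}x{k}:{k}x{n} slowed down {ratio:.2f}x, "
            f"{gemm_tflops(m, k, n, baseline[shape]):.2f} -> "
            f"{gemm_tflops(m, k, n, candidate[shape]):.2f} TFLOPS"
        )
```

Feeding in the real average latencies from each setup would list exactly which shapes regress and by how much.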
